#8340: Add functional grouped convolution support
tapspatel committed May 24, 2024
1 parent 60b4b7c commit 49554c7
Showing 8 changed files with 759 additions and 260 deletions.
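
For context, a grouped convolution splits the input channels into groups independent slices, convolves each slice with its own set of filters, and concatenates the per-group outputs along the channel dimension; groups == in_channels is the depthwise case. Below is a minimal plain-PyTorch sketch of that equivalence (illustrative only, with arbitrary shapes; it mirrors the nn.Conv2d golden path the new tests use):

import torch
import torch.nn.functional as F

# Arbitrary illustrative shapes: 8 input channels, 8 output channels, 4 groups.
batch, in_ch, out_ch, groups = 1, 8, 8, 4
x = torch.randn(batch, in_ch, 16, 16)
w = torch.randn(out_ch, in_ch // groups, 3, 3)  # grouped weight layout: (out, in // groups, kh, kw)
b = torch.randn(out_ch)

# One grouped convolution...
y_grouped = F.conv2d(x, w, b, stride=1, padding=1, groups=groups)

# ...is equivalent to convolving each channel group separately and concatenating the results.
y_split = torch.cat(
    [
        F.conv2d(xg, wg, bg, stride=1, padding=1)
        for xg, wg, bg in zip(x.chunk(groups, dim=1), w.chunk(groups, dim=0), b.chunk(groups, dim=0))
    ],
    dim=1,
)
assert torch.allclose(y_grouped, y_split, atol=1e-5)

# groups == in_channels gives a depthwise convolution, e.g. a 64-channel depthwise 3x3:
depthwise = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, groups=64)
assert depthwise.weight.shape == (64, 1, 3, 3)

The new test_conv_groups below uses exactly this nn.Conv2d path as the golden reference and compares it against ttnn.Conv2d(..., groups=num_groups) at a PCC of 0.99.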
@@ -13,7 +13,7 @@ def trace_conv_to_generate_data_top_left_indices_and_pad_metadata(conv_params, i
output_channels, input_channels, filter_h, filter_w, stride_h, stride_w, pad_h, pad_w, dilation, groups = [
conv_params[i] for i in range(10)
]
assert dilation == 1 and groups == 1
assert dilation == 1
assert len(input_nchw_shape) == 4
input_n, input_c, input_h, input_w = [input_nchw_shape[i] for i in range(4)]
# image 1 data
294 changes: 294 additions & 0 deletions tests/ttnn/unit_tests/operations/test_conv2d.py
@@ -12,6 +12,7 @@
import tt_lib
import math
import os
import torch.nn as nn


# def plot_diff(vals, fid, nsticks, stick_len):
@@ -88,6 +89,7 @@ def run_conv(
fp32_accum=False,
packer_l1_acc=False,
output_layout=ttnn.TILE_LAYOUT,
groups=1,
):
# has_bias = False
has_bias = True
@@ -1267,3 +1269,295 @@ def test_conv_core_nondivis(
use_1d_systolic_array,
config_override,
)


# The following test takes various shape sizes from resnet50, unet and stable diffusion and tests different numbers of groups, all the way up to num_groups = num_in_channels (depthwise conv)
@skip_for_grayskull()
@pytest.mark.parametrize("device_l1_small_size", [16384], indirect=True)
@pytest.mark.parametrize(
"batch_size, in_channels, out_channels, input_height, input_width, kernel_height, kernel_width, stride_h, stride_w, pad_h, pad_w, num_groups, use_1d_systolic_array",
(
(1, 64, 64, 8, 8, 3, 3, 1, 1, 1, 1, 1, True),
(1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, 2, True),
(1, 64, 64, 4, 4, 3, 3, 1, 1, 1, 1, 4, True),
(1, 64, 64, 8, 8, 3, 3, 1, 1, 1, 1, 8, True),
(1, 64, 64, 8, 8, 3, 3, 1, 1, 1, 1, 16, True),
(1, 64, 64, 2, 2, 3, 3, 1, 1, 1, 1, 32, True),
(1, 64, 64, 32, 32, 3, 3, 1, 1, 1, 1, 64, True),
(2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 1, True),
(2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 2, True),
(2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 8, True),
(1, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 1, True),
(8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 64, True),
(4, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 128, True),
(8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 128, True),
(8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 256, False),
# (16, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 256, False), # doesn't fit with bfloat16 weights
# (32, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 512, False), # doesn't fit with bfloat16 weights
(32, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, 40, False),
(32, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, 10, False),
(1, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 8, True),
(1, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 16, True),
(8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, 32, True),
(8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, 2, False),
(8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, 4, False),
# (1, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, 8, False),
(1, 320, 320, 32, 32, 3, 3, 1, 1, 1, 1, 2, False),
(1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, 320, False),
# (1, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, 1, False), # doesn't fit with bfloat16 weights
(2, 64, 32, 66, 10, 3, 3, 1, 1, 1, 1, 32, True),
(2, 32, 96, 132, 20, 3, 3, 1, 1, 1, 1, 2, True),
),
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat16],
)
def test_conv_groups(
device,
use_program_cache,
batch_size,
in_channels,
out_channels,
input_height,
input_width,
kernel_height,
kernel_width,
stride_h,
stride_w,
pad_h,
pad_w,
num_groups,
use_1d_systolic_array,
weights_dtype,
):
# Test parameters
kernel_size = (kernel_height, kernel_width)
stride = (stride_h, stride_w)
padding = (pad_h, pad_w)
math_fidelity = ttnn.MathFidelity.HiFi4

# Create wormhole kernel configuration
compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=math_fidelity,
math_approx_mode=True,
fp32_dest_acc_en=True,
packer_l1_acc=True,
)

# Torch implementation - run the nn.Conv2d version as the golden tensor
# Define original tensors and shapes
input_shape = [batch_size, in_channels, input_height, input_width]
weight_shape = [
out_channels,
in_channels // num_groups,
kernel_size[0],
kernel_size[1],
]
bias_shape = [1, 1, 1, out_channels]

# Define various tensors
torch_input_tensor_nchw = torch.randn(input_shape, dtype=torch.bfloat16).float()
torch_input_tensor_nhwc = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))
torch_weight_tensor = torch.randn(weight_shape, dtype=torch.bfloat16).float()
torch_bias_tensor = torch.randn(bias_shape, dtype=torch.bfloat16).float()

# Define pytorch convolutional layer
torch_conv_layer = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=num_groups,
)
torch_conv_layer.weight = nn.Parameter(torch_weight_tensor)
torch_conv_layer.bias = nn.Parameter(torch_bias_tensor.reshape(-1))

# Apply convolution operation
torch_output_tensor_nchw = torch_conv_layer(torch_input_tensor_nchw)

# TTNN implementation - run the ttnn.Conv2d version as the experimental tensor
# Define ttnn tensors and shapes
ttnn_weight_tensor = ttnn.from_torch(
torch_weight_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
)
ttnn_bias_tensor = ttnn.from_torch(
torch_bias_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
)

# Define ttnn convolution operation
ttnn_conv_layer = ttnn.Conv2d(
device=device,
batch_size=batch_size,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
input_height=input_height,
input_width=input_width,
math_fidelity=math_fidelity,
dtype=ttnn.bfloat16,
weights_dtype=weights_dtype,
deallocate_activation=False,
use_shallow_conv_variant=False,
use_1d_systolic_array=use_1d_systolic_array,
reader_patterns_cache={},
weight=ttnn_weight_tensor,
bias=ttnn_bias_tensor,
conv_blocking_and_parallelization_config_override=None,
enable_auto_formatting=False,
padded_input_channels=None,
compute_kernel_config=compute_kernel_config,
output_layout=ttnn.TILE_LAYOUT,
groups=num_groups,
)

# Convert torch input tensor to ttnn tensor
ttnn_input_tensor = ttnn.from_torch(torch_input_tensor_nhwc, ttnn.bfloat16)

# Move input tensor to device
ttnn_input_tensor_on_device = ttnn_conv_layer.copy_input_to_device(ttnn_input_tensor)

# Apply convolution operation on device
ttnn_output_tensor_on_device_tile_layout = ttnn_conv_layer(ttnn_input_tensor_on_device)

# Convert output and get output from device
ttnn_output_tensor_on_device_row_layout = ttnn.to_layout(
ttnn_output_tensor_on_device_tile_layout, ttnn.ROW_MAJOR_LAYOUT
)
ttnn_output_tensor = ttnn.from_device(ttnn_output_tensor_on_device_row_layout)
torch_ttnn_output_tensor = ttnn.to_torch(ttnn_output_tensor)

# Shape manipulations to ensure golden and experimental tensor are of the same shape
output_shape_nhwc = [
torch_output_tensor_nchw.shape[0],
torch_output_tensor_nchw.shape[2],
torch_output_tensor_nchw.shape[3],
torch_output_tensor_nchw.shape[1],
]
torch_ttnn_output_tensor_nhwc = torch.reshape(torch_ttnn_output_tensor, output_shape_nhwc)
torch_ttnn_output_tensor_nchw = torch.permute(torch_ttnn_output_tensor_nhwc, (0, 3, 1, 2))

passing, pcc_msg = assert_with_pcc(torch_output_tensor_nchw, torch_ttnn_output_tensor_nchw, pcc=0.99)
logger.info(pcc_msg)
assert passing


@pytest.mark.parametrize("device_l1_small_size", [32768], indirect=True)
@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override, use_shallow_conv_variant, groups",
(
# yolov4 convs with batch size 1
# unique convs in yolov4 (complete list) # groups: number
# (1, 32, 32, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 32), # groups: 32
# (1, 32, 32, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 32), # groups: 32
# (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64
# (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64
# (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64
# (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
(1, 128, 128, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 2),  # groups: 2
),
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat16],
)
@pytest.mark.parametrize(
"activations_dtype",
# [ttnn.bfloat8_b, ttnn.bfloat16],
[ttnn.bfloat8_b],
)
@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi])
# @pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT])
def test_yolov4_conv_groups_larger_than_one(
device,
use_program_cache,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
use_1d_systolic_array,
config_override,
use_shallow_conv_variant,
groups,
output_layout,
):
if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b:
pytest.skip("Row major layout not compatible with bfloat8_b")
if output_layout == ttnn.ROW_MAJOR_LAYOUT and input_height >= 1056:
pytest.skip("OOM")
run_conv(
device,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
use_1d_systolic_array,
config_override,
use_shallow_conv_variant=use_shallow_conv_variant,
groups=groups,
padded_input_channels=16 if input_channels == 3 else None,
output_layout=output_layout,
)