In [15]:
# Installing needed package (triton)
! pip install triton



In [16]:
# Importing needed libraries
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import time

In [17]:
# Global variables
DEVICE = "cuda"  # device on which the test tensors are created

# Input image: C_IN channels of H x W elements
C_IN, H, W = 3, 1024, 1024

# Filter bank: C_OUT filters, each spanning FH x FW elements per input channel
C_OUT, FH, FW = 64, 3, 3

In [18]:
# Making torch tensors
# Seed the random number generator so the test data is identical on every run.
torch.manual_seed(0)
tensor_I = torch.rand(1, C_IN, H, W, device=DEVICE)  # Input, assuming that batch_size is one
tensor_F = torch.rand(C_OUT, C_IN, FH, FW, device=DEVICE)  # Weights

In [19]:
# Reference result from the convolution provided by Torch.
# Use this for the correctness check.
# The padding is derived from the filter height (as my_conv2d does) instead of
# being hard-coded, so both outputs keep the same shape if FH is changed.
golden_out = F.conv2d(tensor_I, tensor_F, padding=FH // 2)
print(golden_out.shape)  # (1, C_OUT, OUT_H, OUT_W)

torch.Size([1, 64, 1024, 1024])


In [20]:
@triton.jit
def my_triton_kernel(
    input_channels, input_height, input_width, num_filters, padding, filter_height, filter_width, input, filter, output,
    BLOCK_H: tl.constexpr, BLOCK_W: tl.constexpr
):
    """
    Direct 2D convolution with stride 1; every program computes one output element.

    Grid layout: axis 0 is the output column, axis 1 the output row and
    axis 2 the output channel (filter index).

    Args:
        input_channels: number of input channels that are summed over.
        input_height, input_width: spatial size of the unpadded input, which is
            also the spatial size of the output.
        num_filters: number of filters; not read by the kernel.
        padding: number of padding elements on each side of both spatial axes
            of the already padded input.
        filter_height, filter_width: spatial size of one filter.
        input: pointer to the padded input, flattened from the layout
            (input_channels, input_height + 2 * padding, input_width + 2 * padding).
        filter: pointer to the weights, flattened from the layout
            (num_filters, input_channels, filter_height, filter_width).
        output: pointer to the result, flattened from the layout
            (num_filters, input_height, input_width).
        BLOCK_H, BLOCK_W: compile-time constants; not read by the kernel.
    """
    # Widen the program ids to 64 bits so that the flat offsets derived from
    # them cannot wrap around in 32-bit arithmetic for large tensors.
    out_col = tl.program_id(0).to(tl.int64)
    out_row = tl.program_id(1).to(tl.int64)
    out_chan = tl.program_id(2).to(tl.int64)

    # Quantities that do not change inside the loops.
    padded_height = input_height + 2 * padding
    padded_width = input_width + 2 * padding
    filter_size = filter_height * filter_width
    filter_base = out_chan * input_channels * filter_size

    acc = 0.0
    for channel in range(0, input_channels):
        # First padded row of this channel and first weight of this channel's filter slice.
        chan_row_base = channel * padded_height
        chan_filter_base = filter_base + channel * filter_size
        for row in range(filter_height):
            # Output element (out_row, out_col) reads the padded window whose
            # top-left corner is at (out_row, out_col).
            in_row_offset = (chan_row_base + out_row + row) * padded_width
            f_row_offset = chan_filter_base + row * filter_width
            for col in range(filter_width):
                in_offset = in_row_offset + out_col + col
                f_offset = f_row_offset + col
                acc += tl.load(input + in_offset) * tl.load(filter + f_offset)

    out_offset = (out_chan * input_height + out_row) * input_width + out_col
    tl.store(output + out_offset, acc)


def my_conv2d(input, kernel):
    """
    Run my_triton_kernel as a zero-padded, stride-1 2D convolution and time the launch.

    Args:
        input: torch.Tensor of shape (1, C_IN, H, W); only a batch size of one is supported.
        kernel: torch.Tensor of shape (C_OUT, C_IN, FH, FW) with FH == FW.

    Returns:
        A tuple (output, elapsed_seconds):
        output is a float32 tensor of shape (1, C_OUT, H, W) on the device of `input`;
        elapsed_seconds is the wall-clock time in seconds (multiply by 1000 for
        milliseconds) between launching the kernel and its completion.
        NOTE: Triton compiles a kernel on its first launch, so the first call for
        a given configuration also includes that compilation time.

    Raises:
        ValueError: if the batch size is not one, if the channel counts of
            `input` and `kernel` differ, or if the filter is not square.
    """
    batch_size, input_channels, input_height, input_width = input.shape
    output_channels, kernel_channels, filter_height, filter_width = kernel.shape

    # The kernel only indexes a single image; a larger batch would silently be ignored.
    if batch_size != 1:
        raise ValueError(f"my_conv2d supports batch size 1 only, got batch size {batch_size}")
    if kernel_channels != input_channels:
        raise ValueError(
            f"channel mismatch: input has {input_channels} channels, kernel expects {kernel_channels}"
        )
    # A single padding value is applied to both spatial axes, which is only
    # correct when the filter has the same extent along both of them.
    if filter_height != filter_width:
        raise ValueError(
            f"only square filters are supported, got filter of size {filter_height}x{filter_width}"
        )

    padding = filter_height // 2

    # Keep every tensor on the device of the input instead of a hard-coded one.
    device = input.device

    input = input.to(dtype=torch.float32)
    kernel = kernel.to(device=device, dtype=torch.float32)

    input_with_padding = torch.nn.functional.pad(input, (padding, padding, padding, padding))

    # With stride 1 and this padding the output has the spatial size of the input.
    output = torch.empty((1, output_channels, input_height, input_width), device=device, dtype=torch.float32)

    tri_i = input_with_padding.contiguous().view(-1)
    tri_f = kernel.contiguous().view(-1)
    # view() never copies, so the kernel writes straight into `output`.
    tri_o = output.view(-1)

    # One program per output element: (column, row, output channel).
    grid = (input_width, input_height, output_channels)

    # Wait for the preparation work queued above so that it is not counted.
    torch.cuda.synchronize(device)
    start = time.perf_counter()
    my_triton_kernel[grid](
        input_channels, input_height, input_width, output_channels, padding, filter_height, filter_width,
        tri_i, tri_f, tri_o,
        BLOCK_H=1, BLOCK_W=1
    )
    # Kernel launches are asynchronous; wait until the kernel is done.
    torch.cuda.synchronize(device)
    elapsed_seconds = time.perf_counter() - start

    return output, elapsed_seconds

In [21]:
# Testing
# Compare the result of my_conv2d with the reference convolution from torch.
my_output, execution_time = my_conv2d(tensor_I, tensor_F)
# assert_close takes (actual, expected): the Triton output is the value under
# test and the torch result is the reference the tolerances are relative to.
torch.testing.assert_close(my_output, golden_out)  # Assert statement should be passed
# my_conv2d reports the elapsed time in seconds; convert to milliseconds for display.
print(f"Execution Time for triton kernel (ms): {execution_time * 1000:.3f}")

Execution Time for triton kernel (ms): 1031.304
