In [29]:
import collections 
import torch
import torch.nn as nn
import torch.nn.functional as F

# Problem size for a single-output-channel, stride-1 Conv1d that is simulated
# thread-by-thread in plain Python below and checked against nn.Conv1d.
input_channels = 4
input_length = 4
padding = 2
kernel_size = 5
rows_per_thread = 1  # input channels (rows) owned by each simulated thread
threads = input_channels//rows_per_thread
# Each register row holds `padding` zero slots on both sides of the real samples.
padded_input_length = (padding*2) + input_length
# Stride-1 output length over the padded row (equals input_length only when kernel_size == 2*padding + 1).
output_length = padded_input_length - kernel_size + 1
# The input is staged into shared_mem in 4-element chunks, the weights one element at a time.
input_accesses_per_thread = (input_channels * input_length)//(threads*4)
weight_accesses_per_thread = (input_channels * kernel_size)//(threads)

# The integer divisions above silently drop work when the sizes don't line up; fail loudly instead.
assert input_channels % rows_per_thread == 0, "rows_per_thread must divide input_channels"
assert input_length % 4 == 0, "input rows are copied to registers in 4-element chunks"
assert (input_channels * input_length) % (threads * 4) == 0, "input does not split evenly into 4-element loads per thread"
assert (input_channels * kernel_size) % threads == 0, "weights do not split evenly across threads"
assert output_length > 0, "kernel_size is larger than the padded input"

# Deterministic data 1, 2, 3, ...; the .view calls below make it channel-major
# (channel c occupies d_input[c*input_length:(c+1)*input_length]).
d_input = [i+1 for i in range(input_channels*input_length)]
d_weights = [i+1 for i in range(input_channels*kernel_size)]
d_output = [0] * output_length
# Scratch buffer large enough for either staged array; initialised to -1
# (presumably so slots that were never written stand out).
shared_mem = [-1 for _ in range(max(input_length, kernel_size)*input_channels)]
# Per-thread register files, created zero-filled on first access; in the input
# registers those zeros double as the convolution's padding.
input_registers = collections.defaultdict(lambda : [0 for _ in range(rows_per_thread*padded_input_length)])
weight_registers = collections.defaultdict(lambda : [0 for _ in range(rows_per_thread*kernel_size)])


# Reference implementation: the same data and weights run through PyTorch's Conv1d.
input_tensor = torch.tensor(d_input, dtype=torch.float32).view(1, input_channels, input_length)
weight_tensor = torch.tensor(d_weights, dtype=torch.float32).view(1, input_channels, kernel_size)
conv = nn.Conv1d(in_channels=input_channels, out_channels=1, kernel_size=kernel_size, stride=1, padding=padding, bias=False)
conv.weight = nn.Parameter(weight_tensor)

In [30]:
# Simulate the convolution one phase at a time. Each `for tdIdx in range(threads)`
# loop runs one phase for every simulated thread, and the phases run strictly one
# after another, so a later phase always sees the complete result of earlier ones.

# Phase 1: copy the input into shared_mem, 4 consecutive elements per access;
# thread tdIdx takes chunk tdIdx out of every group of `threads` chunks.
for tdIdx in range(threads):
    for loadIdx in range(input_accesses_per_thread):
        td_offset = 4 * (loadIdx*threads + tdIdx)
        for lane in range(4):
            shared_mem[td_offset + lane] = d_input[td_offset + lane]

# Phase 2: each thread copies its own rows (channels) from shared_mem into its input
# registers. Writes start `padding` slots into each padded row, so the zero-filled
# slots on either side act as the convolution's padding.
for tdIdx in range(threads):
    for rowIdx in range(rows_per_thread):
        for colIdx in range(0, input_length, 4):
            reg_index = padding + rowIdx*padded_input_length + colIdx
            shared_mem_index = input_length*(rows_per_thread*tdIdx + rowIdx) + colIdx
            for lane in range(4):
                input_registers[tdIdx][reg_index + lane] = shared_mem[shared_mem_index + lane]

# Phase 3: copy the weights into shared_mem one element per access, strided by `threads`.
# This overwrites the staged input, which every thread already holds in registers.
for tdIdx in range(threads):
    for loadIdx in range(weight_accesses_per_thread):
        td_offset = (loadIdx*threads) + tdIdx
        shared_mem[td_offset] = d_weights[td_offset]

# Phase 4: each thread copies the kernel rows for its channels into its weight registers.
for tdIdx in range(threads):
    for rowIdx in range(rows_per_thread):
        for colIdx in range(kernel_size):
            reg_idx = (kernel_size*rowIdx) + colIdx
            shared_mem_index = kernel_size*(rows_per_thread*tdIdx + rowIdx) + colIdx
            weight_registers[tdIdx][reg_idx] = shared_mem[shared_mem_index]

# Phase 5: for each output position (tile), every thread dots its rows with the kernel
# window starting at tileIdx and stores the partial sum in shared_mem[tdIdx]; the
# partial sums are then added into d_output.
# Iterating over input_length would only be correct when kernel_size == 2*padding + 1:
# a smaller kernel yields too few outputs and a larger one reads past each padded row.
# d_output is rebuilt here so its length always matches the computed output.
output_length = padded_input_length - kernel_size + 1
d_output = [0.0] * output_length
for tileIdx in range(output_length):
    for tdIdx in range(threads):
        res = 0.0
        for dotIdx in range(kernel_size):
            for rowIdx in range(rows_per_thread):
                res += input_registers[tdIdx][tileIdx + dotIdx + (padded_input_length*rowIdx)] * \
                    weight_registers[tdIdx][dotIdx + (kernel_size*rowIdx)]
        shared_mem[tdIdx] = res
    d_output[tileIdx] = sum(shared_mem[:threads])

In [34]:
# Compare the simulated kernel's result against PyTorch's Conv1d on the same data.
print(d_output)
d_output_tensor = torch.tensor(d_output, dtype=torch.float32).view(1, 1, -1)
output = conv(input_tensor)

print("Output:", output)

# torch.allclose broadcasts its arguments, so a wrong-length d_output (e.g. a single
# element) could pass it; check the shapes explicitly first.
assert d_output_tensor.shape == output.shape, \
    f"Shape mismatch: simulated {tuple(d_output_tensor.shape)} vs reference {tuple(output.shape)}"

# allclose passes when |a - b| <= atol + rtol * |b| element-wise.
rtol, atol = 1e-5, 1e-3
assert torch.allclose(d_output_tensor, output, rtol=rtol, atol=atol), "The outputs are not approximately equal"

print(f"The outputs are approximately equal (atol={atol:g}, rtol={rtol:g}).")

[1412.0, 1916.0, 1780.0, 1334.0]
Output: tensor([[[1412., 1916., 1780., 1334.]]], grad_fn=<ConvolutionBackward0>)
The outputs are approximately equal within a margin of 10^-3.
