In [86]:
import collections 
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Problem size -----------------------------------------------------------
# A 1024-channel signal of length 4, convolved with a single 1024 x 5 filter
# (stride 1, zero padding of 2 on each side).
in_channels = 1024
input_length = 4
kernel_size = 5
padding = 2
padded_input_length = input_length + 2 * padding               # 8
output_length = padded_input_length - kernel_size + 1          # 4 (stride 1)

# --- Simulated launch shape ---------------------------------------------------
# Each of the `threads` threads owns `rows_per_thread` consecutive channels.
threads = 256
rows_per_thread = in_channels // threads                       # 4
assert rows_per_thread * threads == in_channels, "in_channels must be a multiple of threads"

# --- Simulated global memory -------------------------------------------------
# Row-major: channel c occupies d_input[c*input_length : (c+1)*input_length]
# and d_weights[c*kernel_size : (c+1)*kernel_size] (see the .view calls below).
d_input = [i + 1 for i in range(in_channels * input_length)]   # 4096 values
d_weights = [i + 1 for i in range(in_channels * kernel_size)]  # 5120 values
d_output = [0] * output_length

# --- Simulated shared memory and per-thread register files ------------------
# -1 is presumably a sentinel for "not yet loaded"; the staging cells overwrite every slot.
input_shared_mem = [-1 for _ in range(len(d_input))]
weight_shared_mem = [-1 for _ in range(len(d_weights))]
sum_reduce_shared_mem = [0] * threads
# Per thread: rows_per_thread padded input rows (4 * 8 = 32) and rows_per_thread filters (4 * 5 = 20).
input_registers = collections.defaultdict(lambda: [0 for _ in range(rows_per_thread * padded_input_length)])
weight_registers = collections.defaultdict(lambda: [0 for _ in range(rows_per_thread * kernel_size)])

# --- PyTorch reference --------------------------------------------------------
# Conv1d expects input (batch, in_channels, length) and weight (out_channels, in_channels, kernel_size).
input_tensor = torch.tensor(d_input, dtype=torch.float32).view(1, in_channels, input_length)
weight_tensor = torch.tensor(d_weights, dtype=torch.float32).view(1, in_channels, kernel_size)
conv = nn.Conv1d(in_channels=in_channels, out_channels=1, kernel_size=kernel_size,
                 stride=1, padding=padding, bias=False)
conv.weight = nn.Parameter(weight_tensor)

In [87]:
# Stage the input: global (d_input) -> shared memory -> per-thread registers.
vec_width = 4                                             # floats moved per simulated vector load
input_length = padded_input_length - 2 * padding          # unpadded samples per channel (4)

# 1) Global -> shared, coalesced: on pass `i`, thread `tdIdx` copies the
#    vec_width consecutive values starting at vec_width * (i * threads + tdIdx).
assert len(d_input) % (vec_width * threads) == 0, "d_input length must be a multiple of vec_width * threads"
num_input_passes = len(d_input) // (vec_width * threads)  # 4
for i in range(num_input_passes):
    for tdIdx in range(threads):
        td_offset = vec_width * (i * threads + tdIdx)
        for v in range(vec_width):
            input_shared_mem[td_offset + v] = d_input[td_offset + v]

# 2) Shared -> registers: thread `tdIdx` owns `rows_per_thread` consecutive channels.
#    Each channel is written into a padded register row of length padded_input_length,
#    leaving `padding` zeros on each side for the convolution's zero padding.
assert len(d_input) % (threads * input_length) == 0, "d_input length must be a multiple of threads * input_length"
rows_per_thread = len(d_input) // (threads * input_length)  # 4
for i in range(rows_per_thread):
    for tdIdx in range(threads):
        td_offset = rows_per_thread * input_length * tdIdx + input_length * i
        for v in range(input_length):
            input_registers[tdIdx][i * padded_input_length + padding + v] = input_shared_mem[td_offset + v]

In [88]:
# Stage the weights: global (d_weights) -> shared memory -> per-thread registers.
vec_width = 4                                             # floats moved per simulated vector load

# 1) Global -> shared, coalesced: on pass `i`, thread `tdIdx` copies the
#    vec_width consecutive values starting at vec_width * (i * threads + tdIdx).
assert len(d_weights) % (vec_width * threads) == 0, "d_weights length must be a multiple of vec_width * threads"
num_weight_passes = len(d_weights) // (vec_width * threads)  # 5
for i in range(num_weight_passes):
    for tdIdx in range(threads):
        td_offset = vec_width * (i * threads + tdIdx)
        for v in range(vec_width):
            weight_shared_mem[td_offset + v] = d_weights[td_offset + v]

# 2) Shared -> registers: thread `tdIdx` takes the kernel_size taps of each of its
#    `rows_per_thread` channels (same channel ownership as the input staging).
assert len(d_weights) % (threads * kernel_size) == 0, "d_weights length must be a multiple of threads * kernel_size"
rows_per_thread = len(d_weights) // (threads * kernel_size)  # 4
for i in range(rows_per_thread):
    for tdIdx in range(threads):
        base_index = rows_per_thread * kernel_size * tdIdx + kernel_size * i
        # The original split this into a 4-wide copy plus one scalar copy (presumably
        # mirroring a float4 load + scalar load in the real kernel); copying all
        # kernel_size taps gives the same register contents for any kernel_size.
        for v in range(kernel_size):
            weight_registers[tdIdx][i * kernel_size + v] = weight_shared_mem[base_index + v]

In [93]:
# Per output position: every thread dots its padded input rows with its filters,
# then the per-thread partial sums are reduced into d_output[tileIdx].
# Padded index tileIdx + dotIdx maps to unpadded index tileIdx + dotIdx - padding,
# which gives Conv1d's out[t] = sum_{c,k} w[c,k] * x[c, t + k - padding].
rows_per_thread = len(d_weights) // (threads * kernel_size)  # channels per thread (4)
for tileIdx in range(output_length):
    for tdIdx in range(threads):
        x_regs = input_registers[tdIdx]
        w_regs = weight_registers[tdIdx]
        res = 0.0
        for dotIdx in range(kernel_size):
            for rowIdx in range(rows_per_thread):
                res += x_regs[tileIdx + dotIdx + padded_input_length * rowIdx] * \
                    w_regs[dotIdx + kernel_size * rowIdx]
        sum_reduce_shared_mem[tdIdx] = res
    # Simulated block reduction: a plain sequential sum over all thread slots.
    d_output[tileIdx] = sum(sum_reduce_shared_mem)
d_output_tensor = torch.tensor(d_output, dtype=torch.float32).view(1, 1, -1)

In [94]:
print(d_output)
# Reference result from PyTorch; no autograd graph is needed for a comparison.
with torch.no_grad():
    output = conv(input_tensor)


print("Output:", output)

# allclose checks |a - b| <= atol + rtol * |b|. With outputs around 2e10, float32
# spacing is about 2e3, so rtol does the real work and atol=0.1 is negligible.
rtol, atol = 1e-5, 0.1
max_abs_diff = (d_output_tensor - output).abs().max().item()
assert torch.allclose(d_output_tensor, output, rtol=rtol, atol=atol), \
    f"The outputs are not approximately equal (max abs diff = {max_abs_diff})"

print(f"The outputs are approximately equal (rtol={rtol}, atol={atol}; max abs diff = {max_abs_diff}).")

[21484270592.0, 28646747136.0, 28638356480.0, 21479550464.0]
Output: tensor([[[2.1484e+10, 2.8647e+10, 2.8638e+10, 2.1480e+10]]],
       grad_fn=<ConvolutionBackward0>)
The outputs are approximately equal within a margin of 10^-1.
