In [1]:
#### https://pytorch.org/tutorials/advanced/cpp_extension.html

# from torch.utils.cpp_extension import load
# lltm_cpp = load(name="lltm_cpp", sources=["lltm.cpp"])

In [2]:
# from torch.utils.cpp_extension import load

In [3]:
# bmm2x2 = load(name='bmm2x2', sources=['bmm2x2_cuda.cpp', 'bmm2x2_cuda_kernel.cu'], verbose=True)

In [4]:
# # Load Pytorch extension
# module_path = os.path.dirname(__file__)
# upfirdn2d_op = load(
#     "upfirdn2d_new",
#     sources=[
#         os.path.join(module_path, "upfirdn2d.cpp"),
#         os.path.join(module_path, "upfirdn2d_kernel.cu"),
#     ],
#     verbose=True,
# )

In [5]:
# bmm2x2_cuda.forward()

In [6]:
## From : Making of Pair Weight

In [7]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [8]:
# Target device for every `.to(device)` call below: the first CUDA GPU (no CPU fallback is attempted).
device = torch.device("cuda:0")

In [9]:
class PairWeight(nn.Module):
    """Block-diagonal linear layer built from independent 2x2 weight matrices.

    Consecutive feature pairs ``(0, 1), (2, 3), ...`` are each multiplied by
    their own learnable 2x2 matrix using a single ``torch.bmm`` call.

    Args:
        input_dim: number of input features; must be even.

    Attributes:
        weight: ``nn.Parameter`` of shape ``(input_dim // 2, 2, 2)``, initialised
            to identity blocks so the layer starts as the identity map.
    """

    def __init__(self, input_dim):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        # One 2x2 identity per feature pair.
        identity_blocks = torch.eye(2).unsqueeze(0).repeat_interleave(input_dim//2, dim=0)
        self.weight = nn.Parameter(identity_blocks)

    def forward(self, x):
        """Apply the per-pair 2x2 transforms.

        Args:
            x: tensor whose first dimension is the batch and whose remaining
                dimensions flatten to ``input_dim`` features.

        Returns:
            Tensor of shape ``(batch, input_dim)``.
        """
        bs = x.shape[0]
        # `reshape` (rather than `view`) also accepts non-contiguous inputs.
        # Layout for bmm: (pairs, batch, 2) @ (pairs, 2, 2).
        x = x.reshape(bs, -1, 2).transpose(0, 1)
        x = torch.bmm(x, self.weight)
        # Back to (batch, pairs, 2) and flatten the pairs into features.
        return x.transpose(1, 0).reshape(bs, -1)

In [10]:
# Pure-PyTorch layer on `device`, plus a TorchScript-compiled copy of the same module.
pairW = PairWeight(784).to(device)
pairW_s = torch.jit.script(pairW)

# Benchmark input: 1000 rows of 784 features, moved to `device`.
x = torch.randn(1000, 784).to(device)

In [11]:
# %timeit -n 100 -r 7 pairW(x) 
# %timeit pairW(x) 

In [12]:
# %timeit -n 100 -r 7 pairW_s(x) 

In [13]:
## Implementing a custom CUDA kernel for bmm2x2

In [14]:
import bmm2x2_cuda

In [15]:
class BMM2x2Function(torch.autograd.Function):
    """Autograd wrapper around the ``bmm2x2_cuda`` extension's batched product.

    The forward pass is delegated to ``bmm2x2_cuda.forward``. The backward pass
    is written with ``torch.bmm`` using the gradients of a batched matrix
    product ``out = inputs @ weights``.
    """

    @staticmethod
    def forward(ctx, inputs, weights):
        """Run the extension kernel and stash the operands for backward.

        Args:
            ctx: autograd context.
            inputs: left operand (used in this notebook as (pairs, batch, 2)).
            weights: right operand (used in this notebook as (pairs, 2, 2)).

        Returns:
            The first element of the sequence returned by ``bmm2x2_cuda.forward``.
        """
        outputs = bmm2x2_cuda.forward(inputs, weights)
        ctx.save_for_backward(inputs, weights)
        return outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients with respect to ``inputs`` and ``weights``.

        Args:
            ctx: autograd context holding the tensors saved in ``forward``.
            grad_output: gradient of the loss with respect to the forward output.

        Returns:
            ``(del_input, del_weights)``; an entry is ``None`` when autograd
            does not need that gradient.
        """
        inputs, weights = ctx.saved_tensors
        # The earlier call into bmm2x2_cuda.backward referenced an undefined name
        # (`grad_cell`) and ignored the saved tensors, so it could never run.
        # NOTE(review): the extension's backward signature is not visible here;
        # the analytic gradients below are used instead.
        need_input, need_weights = ctx.needs_input_grad[0], ctx.needs_input_grad[1]

        del_input = None
        del_weights = None
        if need_input:
            # d(out)/d(inputs): grad_output @ weights^T, per batch entry.
            del_input = torch.bmm(grad_output, weights.transpose(1, 2))
        if need_weights:
            # d(out)/d(weights): inputs^T @ grad_output, per batch entry.
            del_weights = torch.bmm(inputs.transpose(1, 2), grad_output)

        return del_input, del_weights

In [16]:
class PairWeight2(nn.Module):
    """Pairwise 2x2 layer whose batched product runs through ``BMM2x2Function``.

    Holds one learnable 2x2 matrix per consecutive feature pair, each
    initialised to the identity.
    """

    def __init__(self, input_dim):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        n_pairs = input_dim // 2
        # (n_pairs, 2, 2) stack of identity blocks.
        self.weight = nn.Parameter(
            torch.eye(2).unsqueeze(0).repeat_interleave(n_pairs, dim=0)
        )

    def forward(self, x):
        """Multiply every feature pair of ``x`` by its own 2x2 weight block."""
        batch_size, _ = x.shape[0], x.shape[1]
        # (batch, pairs, 2) -> (pairs, batch, 2), made contiguous before the custom op.
        pairs_first = x.view(batch_size, -1, 2).transpose(0, 1).contiguous()
        mixed = BMM2x2Function.apply(pairs_first, self.weight)
        # Restore (batch, features).
        return mixed.transpose(1, 0).reshape(batch_size, -1)

In [17]:
# bmm2x2_cuda.forward(torch.randn(10, 2, 2).to(device), torch.randn(10,2,2).to(device))

In [18]:
pairW2 = PairWeight2(784).to(device)

In [19]:
# %timeit -n 100 -r 7 pairW2(x) 
# %timeit pairW2(x) 

## Sparse Benchmark

In [20]:
## (row, col) coordinates of a block-diagonal matrix made of 2x2 blocks
N = 784
indices = np.array([
    (row, col)
    for start in range(0, N, 2)      # top-left corner of each 2x2 block
    for row in (start, start + 1)
    for col in (start, start + 1)
])

In [21]:
indices.shape

(1568, 2)

In [22]:
# Sparse values: the flattened 2x2 identity (1, 0, 0, 1) repeated once per feature pair.
vals = torch.eye(2).reshape(-1).repeat(784//2)
# vals = torch.randn(len(indices))

In [23]:
# Sparse COO matrix built from the transposed coordinate array (2 x nnz) and its values, moved to `device`.
sW = torch.sparse_coo_tensor(indices.T, vals, size=(784, 784)).to(device)

In [24]:
sW

tensor(indices=tensor([[  0,   0,   1,  ..., 782, 783, 783],
                       [  0,   1,   0,  ..., 783, 782, 783]]),
       values=tensor([1., 0., 0.,  ..., 0., 0., 1.]),
       device='cuda:0', size=(784, 784), nnz=1568, layout=torch.sparse_coo)

In [25]:
# Random batch of 100 rows with 784 features, moved to `device`.
X = torch.randn(100, 784).to(device)

In [26]:
# Transposes of the sparse weight and the batch, so the product can be formed as (wt @ xt).t().
wt, xt = sW.t(), X.t()

In [44]:
%timeit (wt@xt).t()

641 µs ± 18.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [77]:
# Regroup X from (batch, features) into (pairs, batch, 2) for the batched product.
bs, dim = X.shape
x = X.view(bs, dim // 2, 2).transpose(0, 1).contiguous()

# One 2x2 identity block per feature pair.
w = torch.eye(2).repeat(784//2, 1, 1).contiguous()

In [78]:
w.shape

torch.Size([392, 2, 2])

In [79]:
# Move both batched operands to `device` (w was built by torch.eye with no device argument).
x = x.to(device)
w = w.to(device)

In [80]:
w.shape, x.shape

(torch.Size([392, 2, 2]), torch.Size([392, 100, 2]))

In [90]:
%timeit BMM2x2Function.apply(x.contiguous(), w).transpose(1,0).reshape(bs, -1)

146 µs ± 14.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [46]:
%timeit torch.bmm(x, w)

1.1 ms ± 594 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [34]:
### sparse mm -> 626 µs ± 1.35 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
### BMM2x2 -> 77.7 µs ± 12.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
### BMM2x2+contiguous -> 150 µs ± 9.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
### torch bmm -> 1.11 ms ± 570 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [81]:
# Result A: sparse matrix product, transposed back so rows are batch entries.
ansA = (wt@xt).t()
ansA

tensor([[ 1.5073,  0.7889, -0.4829,  ..., -0.6170,  1.6299,  1.2688],
        [ 0.3308, -0.0600,  0.0109,  ..., -2.6327, -0.0934,  0.4270],
        [-0.0868,  0.2504, -2.2972,  ...,  0.5827, -1.3525, -1.1179],
        ...,
        [-0.3604,  1.2213, -2.5813,  ..., -0.1434, -1.0056,  1.0076],
        [-0.8131, -0.3502,  0.8065,  ..., -0.3307,  0.8443,  0.2059],
        [ 0.6000, -0.2403,  0.7744,  ..., -1.9084,  0.1442,  0.6402]],
       device='cuda:0')

In [82]:
ansA.shape

torch.Size([100, 784])

In [83]:
# Result B: custom CUDA op via BMM2x2Function, rearranged to (bs, features) for comparison.
ansB = BMM2x2Function.apply(x, w).transpose(1,0).reshape(bs, -1)
ansB

tensor([[ 1.5073,  0.7889, -0.4829,  ..., -0.6170,  1.6299,  1.2688],
        [ 0.3308, -0.0600,  0.0109,  ..., -2.6327, -0.0934,  0.4270],
        [-0.0868,  0.2504, -2.2972,  ...,  0.5827, -1.3525, -1.1179],
        ...,
        [-0.3604,  1.2213, -2.5813,  ..., -0.1434, -1.0056,  1.0076],
        [-0.8131, -0.3502,  0.8065,  ..., -0.3307,  0.8443,  0.2059],
        [ 0.6000, -0.2403,  0.7744,  ..., -1.9084,  0.1442,  0.6402]],
       device='cuda:0')

In [85]:
# ansB[1]

In [86]:
# Result C: stock torch.bmm on the same operands, rearranged the same way as result B.
ansC = torch.bmm(x, w).transpose(1,0).reshape(bs, -1)
ansC

tensor([[ 1.5073,  0.7889, -0.4829,  ..., -0.6170,  1.6299,  1.2688],
        [ 0.3308, -0.0600,  0.0109,  ..., -2.6327, -0.0934,  0.4270],
        [-0.0868,  0.2504, -2.2972,  ...,  0.5827, -1.3525, -1.1179],
        ...,
        [-0.3604,  1.2213, -2.5813,  ..., -0.1434, -1.0056,  1.0076],
        [-0.8131, -0.3502,  0.8065,  ..., -0.3307,  0.8443,  0.2059],
        [ 0.6000, -0.2403,  0.7744,  ..., -1.9084,  0.1442,  0.6402]],
       device='cuda:0')

In [87]:
ansC.shape

torch.Size([100, 784])

In [73]:
x

tensor([[[ 1.5073,  0.7889],
         [ 0.3308, -0.0600],
         [-0.0868,  0.2504],
         ...,
         [-0.3604,  1.2213],
         [-0.8131, -0.3502],
         [ 0.6000, -0.2403]],

        [[-0.4829, -0.7721],
         [ 0.0109, -1.4208],
         [-2.2972, -0.0171],
         ...,
         [-2.5813,  2.5553],
         [ 0.8065,  0.2904],
         [ 0.7744, -2.4374]],

        [[-0.0731,  1.9099],
         [-0.2455, -0.1036],
         [ 0.7706,  0.8648],
         ...,
         [ 1.2436,  0.5248],
         [ 1.5319,  1.5474],
         [-2.1175,  0.1074]],

        ...,

        [[ 0.3615, -1.3613],
         [-2.2790,  0.0356],
         [ 0.0497, -0.9031],
         ...,
         [ 0.5367, -0.1016],
         [-2.5365,  1.5789],
         [ 0.7288, -0.3986]],

        [[-1.8808, -0.6170],
         [-1.3180, -2.6327],
         [-0.3965,  0.5827],
         ...,
         [ 0.0499, -0.1434],
         [-0.0149, -0.3307],
         [-1.2435, -1.9084]],

        [[ 1.6299,  1.2688],
       

In [41]:
## Inspect the memory layout of a (B, 2, 2) tensor: consecutive integers expose each element's flat index.
_w = torch.arange(10*2*2).reshape(10,2,2)

In [42]:
_w.shape

torch.Size([10, 2, 2])

In [43]:
_w

tensor([[[ 0,  1],
         [ 2,  3]],

        [[ 4,  5],
         [ 6,  7]],

        [[ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15]],

        [[16, 17],
         [18, 19]],

        [[20, 21],
         [22, 23]],

        [[24, 25],
         [26, 27]],

        [[28, 29],
         [30, 31]],

        [[32, 33],
         [34, 35]],

        [[36, 37],
         [38, 39]]])

In [58]:
# Densify the sparse matrix so its contents can be displayed and compared element-wise.
_t = sW.to_dense()
_t

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]], device='cuda:0')

In [70]:
# Count the entries of _t that equal the 784x784 identity; the square root of that
# count is 784 only when all 784*784 entries match.
(_t == torch.eye(784).to(device)).type(torch.long).sum().sqrt()

tensor(784., device='cuda:0')

In [63]:
w

tensor([[[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        ...,

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]]], device='cuda:0')