In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Target device (CUDA device index 0); the benchmark tensors and modules below are
# moved to it with `.to(device)`.
device = torch.device("cuda:0")

In [3]:
class PairWeight(nn.Module):
    """Applies an independent learnable 2x2 matrix to each consecutive feature pair.

    The ``input_dim`` features are split into ``input_dim // 2`` pairs
    ``(x[2k], x[2k+1])`` and pair ``k`` is multiplied on the right by
    ``weight[k]``. Every block is initialised to the 2x2 identity.
    """

    def __init__(self, input_dim):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        num_pairs = input_dim // 2
        # (num_pairs, 2, 2): one identity block per feature pair
        identity_blocks = torch.eye(2).repeat(num_pairs, 1, 1)
        self.weight = nn.Parameter(identity_blocks)

    def forward(self, x):
        """Map ``x`` of shape (batch, input_dim) to a tensor of the same shape."""
        batch = x.shape[0]
        # torch.bmm needs the pair index as its leading dimension: (num_pairs, batch, 2)
        pairs_first = x.view(batch, -1, 2).transpose(0, 1)
        mixed = torch.bmm(pairs_first, self.weight)
        # back to sample-major order, then flatten the pairs into features again
        return mixed.transpose(1, 0).reshape(batch, -1)

In [4]:
# Pure-PyTorch layer for 784 features (392 pairs), in eager and TorchScript form.
pairW = PairWeight(784).to(device)
pairW_s = torch.jit.script(pairW)

# Random batch of 1000 samples that is fed to the layer implementations below.
x = torch.randn(1000, 784).to(device)

In [5]:
pairW(x)

tensor([[-0.2235, -0.6298,  0.9451,  ..., -0.2658, -1.2232,  1.2944],
        [-0.1163,  1.3137, -0.3774,  ..., -1.4213, -0.1092, -1.4871],
        [-1.2374,  0.2737,  0.6669,  ..., -0.2756, -1.0252,  0.6792],
        ...,
        [-0.4211, -1.9955, -2.0891,  ..., -0.8854,  1.2063,  1.6415],
        [-0.5472, -0.5235, -0.1416,  ...,  0.9046, -0.0857, -0.6555],
        [-0.5857, -0.3508,  2.0595,  ...,  0.3208, -0.7738,  0.7363]],
       device='cuda:0', grad_fn=<UnsafeViewBackward>)

In [6]:
## Implementing my custom cuda code for bmm2x2

In [7]:
import bmm2x2_cuda

In [8]:
class BMM2x2Function(torch.autograd.Function):
    """Autograd bridge to the custom ``bmm2x2_cuda`` extension.

    In this notebook it is called with ``inputs`` of shape (batch, pairs, 2)
    and ``weights`` of shape (pairs, 2, 2), and its result and gradients are
    checked against ``torch.bmm(x.transpose(1, 0), w).transpose(1, 0)``.
    """

    @staticmethod
    def forward(ctx, inputs, weights):
        """Run the extension's forward kernel and return its first output."""
        kernel_outputs = bmm2x2_cuda.forward(inputs, weights)
        # Both operands are handed back to the extension in backward().
        ctx.save_for_backward(inputs, weights)
        return kernel_outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_inputs, grad_weights)`` computed by the extension."""
        saved_inputs, saved_weights = ctx.saved_tensors
        # grad_output is passed through as received (no .contiguous() copy).
        grad_inputs, grad_weights = bmm2x2_cuda.backward(
            saved_inputs, saved_weights, grad_output
        )
        return grad_inputs, grad_weights

In [9]:
class PairWeight2(nn.Module):
    """Pairwise 2x2 linear layer whose forward pass goes through ``BMM2x2Function``.

    It holds one learnable 2x2 block per consecutive feature pair, each
    initialised to the identity.
    """

    def __init__(self, input_dim):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        # (input_dim // 2, 2, 2) stack of identity blocks
        blocks = torch.eye(2).unsqueeze(0).repeat_interleave(input_dim//2, dim=0)
        self.weight = nn.Parameter(blocks)

    def forward(self, x):
        """Map ``x`` of shape (batch, input_dim) to a tensor of the same shape."""
        batch = x.shape[0]
        # The custom op is given the sample-major (batch, pairs, 2) layout directly.
        pairs = x.view(batch, -1, 2)
        mixed = BMM2x2Function.apply(pairs, self.weight)
        return mixed.view(batch, -1)

In [10]:
# bmm2x2_cuda.forward(torch.randn(10, 2, 2).to(device), torch.randn(10,2,2).to(device))

In [11]:
pairW2 = PairWeight2(784).to(device)

In [12]:
pairW2(x) 

tensor([[-0.2235, -0.6298,  0.9451,  ..., -0.2658, -1.2232,  1.2944],
        [-0.1163,  1.3137, -0.3774,  ..., -1.4213, -0.1092, -1.4871],
        [-1.2374,  0.2737,  0.6669,  ..., -0.2756, -1.0252,  0.6792],
        ...,
        [-0.4211, -1.9955, -2.0891,  ..., -0.8854,  1.2063,  1.6415],
        [-0.5472, -0.5235, -0.1416,  ...,  0.9046, -0.0857, -0.6555],
        [-0.5857, -0.3508,  2.0595,  ...,  0.3208, -0.7738,  0.7363]],
       device='cuda:0', grad_fn=<ViewBackward>)

## Sparse Benchmark

In [13]:
## create sparse matrix row and col
N = 784
B = 100
# (row, col) of every entry of the N//2 diagonal 2x2 blocks,
# block by block and row-major inside each block
indices = np.array([
    (row, col)
    for top in range(0, N, 2)
    for row in (top, top + 1)
    for col in (top, top + 1)
])

In [14]:
indices.shape

(1568, 2)

In [15]:
# Flattened values for the sparse matrix: one 2x2 identity per pair, listed in the same
# per-block row-major order as `indices`.
vals = torch.eye(2).unsqueeze(0).repeat_interleave(N//2, dim=0).reshape(-1)
# vals = torch.randn(len(indices))

In [16]:
# N x N sparse COO matrix; `indices.T` gives the (2, nnz) coordinate layout.
sW = torch.sparse_coo_tensor(indices.T, vals, size=(N, N)).to(device)

In [17]:
sW

tensor(indices=tensor([[  0,   0,   1,  ..., 782, 783, 783],
                       [  0,   1,   0,  ..., 783, 782, 783]]),
       values=tensor([1., 0., 0.,  ..., 0., 0., 1.]),
       device='cuda:0', size=(784, 784), nnz=1568, layout=torch.sparse_coo)

In [18]:
X = torch.randn(B, N).to(device)

In [19]:
# Transposed operands so the sparse matrix is the left factor: (wt @ xt).t() equals X @ sW.
wt, xt = sW.t(), X.t()

In [20]:
# %timeit (wt@xt).t()

In [21]:
# x: (batch, pairs, 2) view of X, the layout BMM2x2Function is called with below.
bs, dim = X.shape[0], X.shape[1]
x = X.view(bs, -1, 2)#.transpose(0,1).contiguous()

# w: one 2x2 identity block per pair, shape (N//2, 2, 2).
w = torch.eye(2).unsqueeze(0).repeat_interleave(N//2, dim=0)#.contiguous()

In [22]:
w.shape

torch.Size([392, 2, 2])

In [23]:
x = x.to(device)
w = w.to(device)

In [24]:
w.shape, x.shape

(torch.Size([392, 2, 2]), torch.Size([100, 392, 2]))

In [25]:
# %timeit BMM2x2Function.apply(x, w).view(bs, -1)

In [26]:
# %timeit torch.bmm(x, w)

In [27]:
_x = torch.randn(B, N).to(device)
_w = torch.randn(N, N).to(device)

In [28]:
# %timeit torch.mm(_x, _w)

In [29]:
########## x -> (100, 784) w -> sparse or dense ##########

### sparse mm -> 626 µs ± 1.35 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
### BMM2x2 -> 77.7 µs ± 12.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
### BMM2x2+contiguous -> 150 µs ± 9.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
### BMM2x2+reshape -> 80 µs ± 6.22 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
### torch bmm -> 1.11 ms ± 570 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
### dense mm -> 754 µs ± 283 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

########## x -> (100, 7840)
## sparse mm -> 6.1 ms ± 29.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
## BMM2x2 -> 804 µs ± 56.1 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
## Dense -> 70.7 ms ± 3.61 ms per loop (mean ± std. dev. of 7 runs, 1000 loops each)


########## x -> (1000, 784)
## sparse mm -> 3.57 ms ± 80.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
## BMM2x2 -> 801 µs ± 63 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
## Dense -> 4.44 ms ± 2.39 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [30]:
ansA = (wt@xt).t()
ansA

tensor([[ 1.3629, -0.5071, -0.9541,  ..., -1.2122, -1.5663, -0.4602],
        [-0.2374, -0.4753,  1.2988,  ..., -0.8158,  0.7913,  0.1759],
        [-0.1011, -1.4976,  0.1732,  ...,  0.6395, -0.5420,  0.5178],
        ...,
        [-0.0517,  0.0322,  0.0651,  ...,  0.2235, -0.5103,  0.0496],
        [-0.4108,  0.6535,  0.2936,  ..., -1.0167, -0.0673,  0.0679],
        [ 2.1288,  2.1379, -0.3316,  ...,  1.2718,  0.0641, -0.5905]],
       device='cuda:0')

In [31]:
ansA.shape

torch.Size([100, 784])

In [32]:
ansB = BMM2x2Function.apply(x, w).view(bs, -1)
ansB

tensor([[ 1.3629, -0.5071, -0.9541,  ..., -1.2122, -1.5663, -0.4602],
        [-0.2374, -0.4753,  1.2988,  ..., -0.8158,  0.7913,  0.1759],
        [-0.1011, -1.4976,  0.1732,  ...,  0.6395, -0.5420,  0.5178],
        ...,
        [-0.0517,  0.0322,  0.0651,  ...,  0.2235, -0.5103,  0.0496],
        [-0.4108,  0.6535,  0.2936,  ..., -1.0167, -0.0673,  0.0679],
        [ 2.1288,  2.1379, -0.3316,  ...,  1.2718,  0.0641, -0.5905]],
       device='cuda:0')

In [33]:
# ansB[1]

In [34]:
# ansC = torch.bmm(x, w).transpose(1,0).reshape(bs, -1)
# ansC

In [35]:
# ansC.shape

In [36]:
x

tensor([[[ 1.3629, -0.5071],
         [-0.9541,  0.3553],
         [-0.9223, -0.0636],
         ...,
         [ 0.0351, -0.1227],
         [ 0.9231, -1.2122],
         [-1.5663, -0.4602]],

        [[-0.2374, -0.4753],
         [ 1.2988, -1.0792],
         [ 0.8484, -2.0350],
         ...,
         [-0.6424, -0.1664],
         [-0.3492, -0.8158],
         [ 0.7913,  0.1759]],

        [[-0.1011, -1.4976],
         [ 0.1732, -0.1137],
         [ 1.2251, -0.6367],
         ...,
         [-1.1096, -0.2325],
         [ 0.0678,  0.6395],
         [-0.5420,  0.5178]],

        ...,

        [[-0.0517,  0.0322],
         [ 0.0651,  0.4915],
         [ 2.8140,  1.5796],
         ...,
         [ 0.6711,  0.4032],
         [-1.0173,  0.2235],
         [-0.5103,  0.0496]],

        [[-0.4108,  0.6535],
         [ 0.2936,  0.4095],
         [-0.7818, -0.3301],
         ...,
         [ 1.1950, -0.6354],
         [-1.2095, -1.0167],
         [-0.0673,  0.0679]],

        [[ 2.1288,  2.1379],
       

In [37]:
## testing indices of Bx2x2 array
_w = torch.arange(10*2*2).reshape(10,2,2)

In [38]:
_w.shape

torch.Size([10, 2, 2])

In [39]:
_w

tensor([[[ 0,  1],
         [ 2,  3]],

        [[ 4,  5],
         [ 6,  7]],

        [[ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15]],

        [[16, 17],
         [18, 19]],

        [[20, 21],
         [22, 23]],

        [[24, 25],
         [26, 27]],

        [[28, 29],
         [30, 31]],

        [[32, 33],
         [34, 35]],

        [[36, 37],
         [38, 39]]])

In [40]:
_t = sW.to_dense()
_t

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]], device='cuda:0')

In [41]:
# Count the entries where the densified sparse matrix equals the identity; the square
# root of that count is 784 only when all 784 x 784 entries match.
(_t == torch.eye(784).to(device)).type(torch.long).sum().sqrt()

tensor(784., device='cuda:0')

In [42]:
# Replace the identity blocks with random weights of the same shape before testing backward.
w = torch.randn_like(w)

## Test Backward Function

In [43]:
# Track gradients on both operands of the custom op.
x.requires_grad = True
w.requires_grad = True

# Start from empty gradients so the values read after backward() come from this run only.
x.grad = None
w.grad = None

In [44]:
x

tensor([[[ 1.3629, -0.5071],
         [-0.9541,  0.3553],
         [-0.9223, -0.0636],
         ...,
         [ 0.0351, -0.1227],
         [ 0.9231, -1.2122],
         [-1.5663, -0.4602]],

        [[-0.2374, -0.4753],
         [ 1.2988, -1.0792],
         [ 0.8484, -2.0350],
         ...,
         [-0.6424, -0.1664],
         [-0.3492, -0.8158],
         [ 0.7913,  0.1759]],

        [[-0.1011, -1.4976],
         [ 0.1732, -0.1137],
         [ 1.2251, -0.6367],
         ...,
         [-1.1096, -0.2325],
         [ 0.0678,  0.6395],
         [-0.5420,  0.5178]],

        ...,

        [[-0.0517,  0.0322],
         [ 0.0651,  0.4915],
         [ 2.8140,  1.5796],
         ...,
         [ 0.6711,  0.4032],
         [-1.0173,  0.2235],
         [-0.5103,  0.0496]],

        [[-0.4108,  0.6535],
         [ 0.2936,  0.4095],
         [-0.7818, -0.3301],
         ...,
         [ 1.1950, -0.6354],
         [-1.2095, -1.0167],
         [-0.0673,  0.0679]],

        [[ 2.1288,  2.1379],
       

In [45]:
y = BMM2x2Function.apply(x, w).view(bs, -1)

In [46]:
y.mean().backward()

In [47]:
w.grad

tensor([[[-1.0022e-04, -1.0022e-04],
         [-2.2904e-05, -2.2904e-05]],

        [[ 4.1111e-04,  4.1111e-04],
         [-8.4858e-05, -8.4858e-05]],

        [[ 5.6670e-05,  5.6670e-05],
         [-8.6980e-06, -8.6980e-06]],

        ...,

        [[ 7.5286e-05,  7.5286e-05],
         [ 7.4401e-06,  7.4401e-06]],

        [[-1.0035e-04, -1.0035e-04],
         [ 2.0920e-05,  2.0920e-05]],

        [[-1.6806e-04, -1.6806e-04],
         [-3.2195e-05, -3.2195e-05]]], device='cuda:0')

In [48]:
x.grad

tensor([[[ 2.7387e-05, -6.3784e-06],
         [-1.5906e-05, -4.1040e-05],
         [ 4.5410e-05,  5.9082e-08],
         ...,
         [-3.2505e-05,  3.7755e-06],
         [-4.7682e-06, -1.0622e-05],
         [-2.4337e-06,  3.5021e-06]],

        [[ 2.7387e-05, -6.3784e-06],
         [-1.5906e-05, -4.1040e-05],
         [ 4.5410e-05,  5.9082e-08],
         ...,
         [-3.2505e-05,  3.7755e-06],
         [-4.7682e-06, -1.0622e-05],
         [-2.4337e-06,  3.5021e-06]],

        [[ 2.7387e-05, -6.3784e-06],
         [-1.5906e-05, -4.1040e-05],
         [ 4.5410e-05,  5.9082e-08],
         ...,
         [-3.2505e-05,  3.7755e-06],
         [-4.7682e-06, -1.0622e-05],
         [-2.4337e-06,  3.5021e-06]],

        ...,

        [[ 2.7387e-05, -6.3784e-06],
         [-1.5906e-05, -4.1040e-05],
         [ 4.5410e-05,  5.9082e-08],
         ...,
         [-3.2505e-05,  3.7755e-06],
         [-4.7682e-06, -1.0622e-05],
         [-2.4337e-06,  3.5021e-06]],

        [[ 2.7387e-05, -6.3784e-06

In [49]:
# Keep the gradients obtained through BMM2x2Function for comparison with the torch.bmm ones.
xgrad = x.grad
wgrad = w.grad

### test using bmm

In [50]:
x.grad = None
w.grad = None

In [51]:
# Reference result with stock torch.bmm: move the pair axis to the front for bmm,
# move it back, then flatten to (B, N).
y_ = torch.bmm(x.transpose(1,0), w).transpose(1,0).reshape(B, -1)

In [52]:
y_ - y

tensor([[-1.1921e-07,  0.0000e+00,  0.0000e+00,  ..., -5.9605e-08,
          2.9802e-08,  0.0000e+00],
        [ 1.1176e-08,  0.0000e+00,  0.0000e+00,  ..., -3.7253e-09,
          1.4901e-08,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 0.0000e+00,  0.0000e+00, -5.9605e-08,  ...,  0.0000e+00,
          0.0000e+00,  2.9802e-08],
        [-5.9605e-08, -2.9802e-08,  0.0000e+00,  ...,  2.9802e-08,
          0.0000e+00,  0.0000e+00],
        [-1.1921e-07,  0.0000e+00,  0.0000e+00,  ..., -5.9605e-08,
          0.0000e+00,  0.0000e+00]], device='cuda:0', grad_fn=<SubBackward0>)

In [53]:
y_.mean().backward()

In [54]:
x.grad

tensor([[[ 2.7387e-05, -6.3784e-06],
         [-1.5906e-05, -4.1040e-05],
         [ 4.5410e-05,  5.9082e-08],
         ...,
         [-3.2505e-05,  3.7755e-06],
         [-4.7682e-06, -1.0622e-05],
         [-2.4337e-06,  3.5021e-06]],

        [[ 2.7387e-05, -6.3784e-06],
         [-1.5906e-05, -4.1040e-05],
         [ 4.5410e-05,  5.9082e-08],
         ...,
         [-3.2505e-05,  3.7755e-06],
         [-4.7682e-06, -1.0622e-05],
         [-2.4337e-06,  3.5021e-06]],

        [[ 2.7387e-05, -6.3784e-06],
         [-1.5906e-05, -4.1040e-05],
         [ 4.5410e-05,  5.9082e-08],
         ...,
         [-3.2505e-05,  3.7755e-06],
         [-4.7682e-06, -1.0622e-05],
         [-2.4337e-06,  3.5021e-06]],

        ...,

        [[ 2.7387e-05, -6.3784e-06],
         [-1.5906e-05, -4.1040e-05],
         [ 4.5410e-05,  5.9082e-08],
         ...,
         [-3.2505e-05,  3.7755e-06],
         [-4.7682e-06, -1.0622e-05],
         [-2.4337e-06,  3.5021e-06]],

        [[ 2.7387e-05, -6.3784e-06

In [55]:
w.grad

tensor([[[-1.0022e-04, -1.0022e-04],
         [-2.2904e-05, -2.2904e-05]],

        [[ 4.1111e-04,  4.1111e-04],
         [-8.4858e-05, -8.4858e-05]],

        [[ 5.6670e-05,  5.6670e-05],
         [-8.6980e-06, -8.6980e-06]],

        ...,

        [[ 7.5286e-05,  7.5286e-05],
         [ 7.4401e-06,  7.4401e-06]],

        [[-1.0035e-04, -1.0035e-04],
         [ 2.0920e-05,  2.0920e-05]],

        [[-1.6806e-04, -1.6806e-04],
         [-3.2195e-05, -3.2195e-05]]], device='cuda:0')

In [56]:
# Input gradients from torch.bmm must match the ones saved from the custom kernel.
# NOTE(review): newer PyTorch releases deprecate assert_allclose in favour of
# torch.testing.assert_close — switch when upgrading.
torch.testing.assert_allclose(x.grad, xgrad)

In [57]:
# Weight gradients from torch.bmm must match the ones saved from the custom kernel.
torch.testing.assert_allclose(w.grad, wgrad)

### Timing the backward

In [81]:
## create sparse matrix row and col
N = 784
B = 1000
# coordinates of the four entries of each diagonal 2x2 block, in row-major order per block
indices = np.array([
    (top + row_off, top + col_off)
    for top in range(0, N, 2)
    for row_off in (0, 1)
    for col_off in (0, 1)
])
# one flattened 2x2 identity per block, matching the order of `indices`
vals = torch.eye(2).unsqueeze(0).repeat_interleave(N//2, dim=0).reshape(-1)
sW = torch.sparse_coo_tensor(
    indices.T,
    vals,
    size=(N, N),
).to(device)

In [82]:
X = torch.randn(B, N).to(device)

In [83]:
wt, xt = sW.t(), X.t()

In [84]:
# Operands for the block-wise implementations: x is a (batch, pairs, 2) view of X,
# w holds one 2x2 identity block per pair.
bs, dim = X.shape[0], X.shape[1]
x = X.view(bs, -1, 2)#.transpose(0,1).contiguous()
w = torch.eye(2).unsqueeze(0).repeat_interleave(N//2, dim=0)#.contiguous()

In [85]:
# Track gradients on the sparse-matmul operands so its backward pass can be timed.
wt.requires_grad = True
xt.requires_grad = True

In [86]:
# Operands shared by the BMM2x2Function and torch.bmm runs, with gradient tracking enabled.
x = x.to(device)
w = w.to(device)
w.requires_grad = True
x.requires_grad = True

In [87]:
# Dense baseline: a full N x N random weight matrix and a separate random input.
_x = torch.randn(B, N).to(device)
_w = torch.randn(N, N).to(device)
_w.requires_grad = True
_x.requires_grad = True

In [88]:
### now perform operation using each method, and then benchmark

In [89]:
# Forward passes only; each scalar keeps its graph so backward() can be timed repeatedly
# with retain_graph=True. In order: custom BMM2x2, torch.bmm, sparse matmul, dense matmul.
_y0 = BMM2x2Function.apply(x, w).mean()
_y1 = torch.bmm(x.transpose(1,0), w).mean()
_y2 = torch.sparse.mm(wt,xt).t().mean()
_y3 = torch.mm(_x, _w).mean()

In [90]:
%timeit _y0.backward(retain_graph=True)

2.43 ms ± 119 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [91]:
# %timeit _y1.backward(retain_graph=True)

In [92]:
# %timeit _y2.backward(retain_graph=True)

In [None]:
# %timeit _y3.backward(retain_graph=True)
# %timeit -n 100 -r 7 _y3.backward(retain_graph=True)

In [None]:
### For N=784, B=1000 ###### sequentially -- bmm2x2, bmm, sparsemm, densemm
# 4.75 ms ± 878 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# 14.4 ms ± 2.44 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 6.54 ms ± 250 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 10.1 ms ± 7.44 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
## BMM2x2 backward_v2
# 2.43 ms ± 239 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

### for N=7840, B=100
# 4.44 ms ± 502 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# 19.8 ms ± 3.71 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 51 ms ± 745 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 185 ms ± 66.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
## BMM2x2 backward_v2
# 2.32 ms ± 84.4 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

### for N=784, B=100
# 413 µs ± 120 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# 2.12 ms ± 4.34 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# 1.67 ms ± 87.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 2.13 ms ± 745 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
## BMM2x2 backward_v2
# 244 µs ± 75.6 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [None]:
# End of the BMM2x2 benchmarks. The cells below are a separate sanity check on nn.Linear gradients.

## linear grad check

In [None]:
## https://stats.stackexchange.com/questions/358786/mean-or-sum-of-gradients-for-weight-updates-in-sgd
### sum or mean for gradient in sgd ?? Ans: Sum on grad, mean on loss

In [None]:
# Small bias-free linear layer and a gradient-tracking input, both created on the default device.
lin = nn.Linear(20, 10, bias=False)
x = torch.randn(5, 20)
x.requires_grad = True

In [None]:
lin(x).mean().backward()

In [None]:
xgrad = x.grad

In [None]:
wgrad = lin.weight.grad

In [None]:
x.grad = None
lin.weight.grad = None

In [None]:
(x@lin.weight.t()).mean().backward()

In [None]:
x.grad-xgrad

In [None]:
lin.weight.grad-wgrad