In [2]:
from torchvision.models import mobilenet_v3_large
import torch

x = torch.randn([1, 3, 224, 224])
model = mobilenet_v3_large()

def hook(module, args, output):
    print(args[0].shape)

for mod in model.modules():
    if isinstance(mod, torch.nn.Conv2d) and mod.groups == mod.in_channels:
        mod.register_forward_hook(hook)

model(x);

torch.Size([1, 16, 112, 112])
torch.Size([1, 64, 112, 112])
torch.Size([1, 72, 56, 56])
torch.Size([1, 72, 56, 56])
torch.Size([1, 120, 28, 28])
torch.Size([1, 120, 28, 28])
torch.Size([1, 240, 28, 28])
torch.Size([1, 200, 14, 14])
torch.Size([1, 184, 14, 14])
torch.Size([1, 184, 14, 14])
torch.Size([1, 480, 14, 14])
torch.Size([1, 672, 14, 14])
torch.Size([1, 672, 14, 14])
torch.Size([1, 960, 7, 7])
torch.Size([1, 960, 7, 7])


In [4]:
import torch
from torch import nn

In [5]:
model = nn.Sequential(
    nn.Conv2d(32, 32, kernel_size=3, padding=1, groups=32, bias=False),
    nn.Conv2d(32, 64, kernel_size=1, bias=False)
).cuda()

x = torch.randn([1, 32, 64, 64], requires_grad=False).cuda()

In [6]:
model[0].weight.shape

torch.Size([32, 1, 3, 3])

In [7]:
model[1].weight.shape

torch.Size([64, 32, 1, 1])

In [8]:
for param in model.parameters():
    if param.grad is not None:
        param.grad.data.fill_(0.0)

y = model(x)
y.mean().backward()

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [29]:
import torch
from torch import nn
from monarch_cuda import conv2d_forward

n, num_channels, h, w = 32, 16, 112, 112
#n, num_channels, h, w = 16, 960, 7, 7
n, num_channels, h, w = 32, 240, 28, 28
# n = 4
# num_channels = 512
# h = w = 64

depthwise_conv2d = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, groups=num_channels, bias=False).cuda()
x = torch.randn([n, num_channels, h, w], requires_grad=False).cuda()

@torch.no_grad()
def run():
    y = depthwise_conv2d(x)
    torch.cuda.synchronize()
    return y

y = run()

@torch.no_grad()
def run_my():
    y = conv2d_forward(x, depthwise_conv2d.weight.contiguous(), 1)
    torch.cuda.synchronize()
    return y

z = run_my()

print(torch.allclose(y, z))

True


In [30]:
%timeit -n 1000 run_my()

1.17 ms ± 61.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [31]:
%timeit -n 1000 run()

1.72 ms ± 8.97 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [4]:
import torch
from torch import nn
from monarch_cuda import conv2d_backward

# n, num_channels, h, w = 32, 16, 112, 112
#n, num_channels, h, w = 16, 960, 7, 7
# n, num_channels, h, w = 32, 240, 28, 28
n, num_channels, h, w = 1, 2, 5, 5
# n = 4
# num_channels = 512
# h = w = 64

depthwise_conv2d = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, groups=num_channels, bias=False).cuda()
x = torch.randn([n, num_channels, h, w], requires_grad=True).cuda()

states = {}
def log_bwd(module, grad_input, grad_output):
    states["grad_input"] = grad_input
    states["grad_output"] = grad_output

depthwise_conv2d.register_full_backward_hook(log_bwd)

y = depthwise_conv2d(x)
y.exp().sum().backward()

print(states)

{'grad_input': (tensor([[[[-1.1761, -0.4659, -1.0021, -1.1099, -0.4924],
          [-1.6424, -0.9335, -1.5599, -0.5499, -0.0767],
          [-0.7760, -1.1514,  0.8276, -0.9263, -0.4568],
          [ 0.4002, -1.2251, -0.7052, -0.5529, -0.4925],
          [-0.2923,  0.1769,  0.2545, -0.1098,  0.0704]],

         [[ 0.4894,  0.2689,  0.3237,  0.3378,  0.2755],
          [-0.1332,  0.6338,  0.4120,  0.3929,  0.3878],
          [ 0.0254,  0.3548,  0.3603,  0.3797,  0.4281],
          [ 0.0468,  0.2192,  0.5309,  0.2608,  0.3510],
          [-0.0229,  0.1157, -0.0219,  0.2763,  0.2349]]]], device='cuda:0'),), 'grad_output': (tensor([[[[2.3219, 0.4889, 0.7353, 1.4784, 0.8255],
          [1.7288, 0.7952, 1.1299, 1.7820, 0.5137],
          [0.8228, 4.3713, 0.4083, 0.7533, 0.7121],
          [0.5583, 0.4888, 1.1688, 1.1043, 0.5728],
          [0.9874, 1.2586, 0.7029, 1.1646, 0.9361]],

         [[1.8999, 0.8865, 0.9262, 1.1549, 0.8149],
          [1.1085, 1.5337, 1.0881, 1.2519, 0.9353],
       

In [5]:
print(states)
with torch.no_grad():
    din, dweights = conv2d_backward(states["grad_output"][0], x.clone(), depthwise_conv2d.weight, 1)
print(din)
print(states)
# print(din.shape)
# print(dweights)
# print(dweights.shape)
torch.allclose(states["grad_input"][0], din)

{'grad_input': (tensor([[[[-1.1761, -0.4659, -1.0021, -1.1099, -0.4924],
          [-1.6424, -0.9335, -1.5599, -0.5499, -0.0767],
          [-0.7760, -1.1514,  0.8276, -0.9263, -0.4568],
          [ 0.4002, -1.2251, -0.7052, -0.5529, -0.4925],
          [-0.2923,  0.1769,  0.2545, -0.1098,  0.0704]],

         [[ 0.4894,  0.2689,  0.3237,  0.3378,  0.2755],
          [-0.1332,  0.6338,  0.4120,  0.3929,  0.3878],
          [ 0.0254,  0.3548,  0.3603,  0.3797,  0.4281],
          [ 0.0468,  0.2192,  0.5309,  0.2608,  0.3510],
          [-0.0229,  0.1157, -0.0219,  0.2763,  0.2349]]]], device='cuda:0'),), 'grad_output': (tensor([[[[2.3219, 0.4889, 0.7353, 1.4784, 0.8255],
          [1.7288, 0.7952, 1.1299, 1.7820, 0.5137],
          [0.8228, 4.3713, 0.4083, 0.7533, 0.7121],
          [0.5583, 0.4888, 1.1688, 1.1043, 0.5728],
          [0.9874, 1.2586, 0.7029, 1.1646, 0.9361]],

         [[1.8999, 0.8865, 0.9262, 1.1549, 0.8149],
          [1.1085, 1.5337, 1.0881, 1.2519, 0.9353],
       

True

In [2]:
states["grad_output"][0]

tensor([[[[1.3339, 0.6068, 1.4437, 0.4591, 3.2536],
          [0.4609, 0.8319, 1.2438, 1.3593, 0.4581],
          [3.2788, 1.1048, 1.0025, 0.9150, 0.8091],
          [0.5265, 2.7995, 1.8639, 1.1154, 1.9408],
          [1.0752, 0.4621, 0.7071, 0.9627, 0.6054]],

         [[0.9014, 2.2644, 1.0638, 0.7618, 1.0463],
          [0.4676, 3.2272, 1.1820, 0.8296, 3.7058],
          [0.9902, 1.6993, 0.2760, 1.1363, 3.5201],
          [0.7246, 0.8807, 0.8306, 1.5082, 1.4903],
          [0.9297, 2.2455, 1.0649, 0.8699, 1.2829]]]], device='cuda:0')

In [3]:
states["grad_input"]

(tensor([[[[-0.3243,  0.2552,  0.0164,  0.2652, -0.7573],
           [ 0.9317,  0.1859,  0.2556,  0.3482,  1.0515],
           [-0.5501,  1.2759,  0.7390,  0.3441,  0.0424],
           [ 0.8952, -1.3240,  0.2384,  0.2251, -0.2865],
           [ 0.2867,  0.9382, -0.4277, -0.1551,  0.1596]],
 
          [[ 0.7385, -0.4478, -0.8459, -0.0750, -0.9847],
           [ 1.1138, -0.2448, -0.7152,  0.7494, -0.9785],
           [ 1.1503, -0.8641, -0.2171,  1.4930, -1.4552],
           [ 0.3670, -0.7512,  0.2272,  0.7006, -1.5908],
           [ 0.7140,  0.2325, -0.1341,  0.1210, -0.4547]]]], device='cuda:0'),)