In [2]:
from torchvision.models import mobilenet_v3_large
import torch

# Dummy input: a single 3-channel 224x224 image filled with Gaussian noise.
x = torch.randn([1, 3, 224, 224])
# MobileNetV3-Large built with default constructor arguments. It is only used for one
# shape-probing forward pass, so put it in eval mode: the probe then neither updates
# normalisation running statistics nor applies dropout.
model = mobilenet_v3_large().eval()

def hook(module, args, output):
    """Forward hook: print the shape of the first positional input passed to `module`.

    `module` and `output` are accepted only because the forward-hook signature
    requires them; they are not used.
    """
    first_input = args[0]
    print(first_input.shape)

# Attach the shape-printing hook to every convolution whose group count equals its
# input channel count (i.e. the depthwise convolutions), keeping the returned handles
# so the hooks can be detached again afterwards.
handles = [
    mod.register_forward_hook(hook)
    for mod in model.modules()
    if isinstance(mod, torch.nn.Conv2d) and mod.groups == mod.in_channels
]

try:
    # Only the printed shapes matter, so skip building an autograd graph.
    with torch.no_grad():
        model(x)
finally:
    # Detach the hooks even if the forward pass fails; otherwise every later
    # model(x) call (or a re-run of this cell) would print the shapes again.
    for handle in handles:
        handle.remove()

torch.Size([1, 16, 112, 112])
torch.Size([1, 64, 112, 112])
torch.Size([1, 72, 56, 56])
torch.Size([1, 72, 56, 56])
torch.Size([1, 120, 28, 28])
torch.Size([1, 120, 28, 28])
torch.Size([1, 240, 28, 28])
torch.Size([1, 200, 14, 14])
torch.Size([1, 184, 14, 14])
torch.Size([1, 184, 14, 14])
torch.Size([1, 480, 14, 14])
torch.Size([1, 672, 14, 14])
torch.Size([1, 672, 14, 14])
torch.Size([1, 960, 7, 7])
torch.Size([1, 960, 7, 7])


In [4]:
import torch
from torch import nn

In [5]:
# Depthwise-separable convolution: a 3x3 depthwise conv (groups == channels, so one filter
# per channel) followed by a 1x1 pointwise conv that widens 32 -> 64 channels.
model = nn.Sequential(
    nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, groups=32, bias=False),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, bias=False),
)
model = model.cuda()

# One 32-channel 64x64 input; tensors from randn do not require grad unless asked to.
x = torch.randn(1, 32, 64, 64).cuda()

In [6]:
# Depthwise layer weight: groups=32 gives each of the 32 output channels one single-channel 3x3 filter.
model[0].weight.shape

torch.Size([32, 1, 3, 3])

In [7]:
# Pointwise layer weight: 64 output channels, each a 1x1 filter spanning all 32 input channels.
model[1].weight.shape

torch.Size([64, 32, 1, 1])

In [8]:
# Clear gradients left over from a previous run of this cell so the backward pass below
# does not accumulate into stale values. Module.zero_grad() is the supported way to do
# this and avoids reaching into `.grad.data`, which bypasses autograd's bookkeeping.
model.zero_grad()

# Forward through the depthwise + pointwise stack, then backpropagate a scalar (the mean
# of the output) to populate the parameter gradients.
y = model(x)
y.mean().backward()

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [29]:
import torch
from torch import nn
from monarch_cuda import conv2d_forward

# Benchmark problem size (batch, channels, height, width). Keep exactly one assignment
# active; the alternatives stay commented out so a dead assignment is never executed.
# n, num_channels, h, w = 32, 16, 112, 112
#n, num_channels, h, w = 16, 960, 7, 7
n, num_channels, h, w = 32, 240, 28, 28
# n = 4
# num_channels = 512
# h = w = 64

# Reference layer: 3x3 depthwise convolution (groups == channels) with padding 1 and no bias.
depthwise_conv2d = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, groups=num_channels, bias=False).cuda()
# Random input on the GPU; gradients are not needed for the forward-only benchmark.
x = torch.randn([n, num_channels, h, w], requires_grad=False).cuda()

def run():
    """Apply the PyTorch depthwise convolution to the global `x` with autograd disabled.

    Calls torch.cuda.synchronize() before returning, so the call only returns after
    the queued GPU work has finished. Returns the layer output.
    """
    with torch.no_grad():
        out = depthwise_conv2d(x)
    torch.cuda.synchronize()
    return out

# Reference output from PyTorch's own depthwise convolution.
y = run()

def run_my():
    """Apply the custom `conv2d_forward` kernel to the global `x` with autograd disabled.

    The layer's weight is made contiguous before being handed to the extension.
    The trailing 1 is presumably the padding (the reference layer uses padding=1);
    TODO confirm against the monarch_cuda signature.
    Calls torch.cuda.synchronize() before returning the kernel output.
    """
    with torch.no_grad():
        weight = depthwise_conv2d.weight.contiguous()
        out = conv2d_forward(x, weight, 1)
    torch.cuda.synchronize()
    return out

# Run the custom kernel on the same input and check that it agrees with the PyTorch
# reference output `y` (torch.allclose with its default tolerances).
z = run_my()

print(torch.allclose(y, z))

True


In [30]:
# Time the custom kernel: 1000 calls per measurement loop (each call ends with a CUDA sync).
%timeit -n 1000 run_my()

1.17 ms ± 61.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [31]:
# Time the PyTorch reference the same way: 1000 calls per loop, each ending with a CUDA sync.
%timeit -n 1000 run()

1.72 ms ± 8.97 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [1]:
import torch
from torch import nn
from monarch_cuda import conv2d_backward

# Problem size (batch, channels, height, width). A tiny case is active so the printed
# gradients stay readable; the larger benchmark shapes are kept commented out.
# n, num_channels, h, w = 32, 16, 112, 112
#n, num_channels, h, w = 16, 960, 7, 7
# n, num_channels, h, w = 32, 240, 28, 28
n, num_channels, h, w = 1, 2, 5, 5
# n = 4
# num_channels = 512
# h = w = 64

# Reference layer: 3x3 depthwise convolution (groups == channels) with padding 1 and no bias.
depthwise_conv2d = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, groups=num_channels, bias=False).cuda()
# Move the tensor to the GPU first and only then enable gradients. Calling .cuda() on a
# tensor that already requires grad returns a non-leaf copy, so `x.grad` would never be
# populated (PyTorch warns when it is accessed). Sampling still happens on the CPU, so the
# random stream is the same as before.
x = torch.randn([n, num_channels, h, w]).cuda().requires_grad_()

# Filled in by log_bwd with the gradients seen during the most recent backward pass.
states = {}

def log_bwd(module, grad_input, grad_output):
    """Backward hook: stash the layer's input and output gradients in `states`.

    Returns None, so the gradients themselves are passed through unchanged.
    """
    states.update({"grad_input": grad_input, "grad_output": grad_output})

# Capture the gradients flowing through the layer during the backward pass below.
depthwise_conv2d.register_full_backward_hook(log_bwd)

# Reduce through exp() before summing: the upstream gradient is then exp(y), which
# differs per element, unlike the all-ones gradient a plain sum() would give.
y = depthwise_conv2d(x)
torch.sum(torch.exp(y)).backward()

print(states)

{'grad_input': (tensor([[[[ 1.1685,  1.3499,  0.4227,  0.2620,  0.8351],
          [ 0.9209,  2.0667,  0.4926,  1.3879,  1.5142],
          [-0.4524,  1.5968,  2.3990, -0.5796,  1.1033],
          [-0.4179,  0.5972,  1.9350,  0.2136,  0.7750],
          [ 0.0033, -0.5234,  0.4973,  1.2309,  0.6199]],

         [[-0.6326, -0.5939, -0.1359, -0.0636,  0.0743],
          [-0.1904, -0.5326, -0.9812, -1.1572, -0.3421],
          [-0.3705,  0.3074, -0.0607, -0.3725, -0.1068],
          [-0.5645, -0.7567, -0.9802, -0.3866,  0.2590],
          [ 0.0823, -0.0074, -0.2114,  0.2522,  0.8198]]]], device='cuda:0'),), 'grad_output': (tensor([[[[1.8389, 1.4754, 0.7583, 0.8973, 1.1330],
          [1.4409, 2.0260, 1.6827, 0.4276, 1.5788],
          [0.4251, 3.3738, 0.6538, 1.0533, 3.3021],
          [0.7232, 0.5160, 2.7773, 0.8202, 0.4945],
          [0.5850, 0.3772, 1.7576, 1.0424, 0.4658]],

         [[0.4863, 0.7108, 1.1429, 1.1681, 1.4493],
          [2.3139, 1.0769, 0.4998, 0.3559, 0.9162],
       

In [3]:
# Cross-check the custom backward kernel against the gradients autograd recorded in `states`.
# `states` is printed before and after the call, which makes it visible if the call
# changed the recorded tensors.
print(states)
with torch.no_grad():
    # Clones are passed in, presumably so the extension cannot modify the recorded tensors,
    # the input, or the layer weight in place — TODO confirm.
    # The trailing 1 is presumably the padding (the layer uses padding=1) — verify against
    # the monarch_cuda signature.
    din, dweights = conv2d_backward(states["grad_output"][0].clone(), x.clone(), depthwise_conv2d.weight.clone(), 1)
print(din)
print(states)
# print(din.shape)
# print(dweights)
# print(dweights.shape)
# NOTE(review): only the input gradient is validated here; `dweights` is returned but never
# compared against depthwise_conv2d.weight.grad.
torch.allclose(states["grad_input"][0], din)

{'grad_input': (tensor([[[[ 1.1685,  1.3499,  0.4227,  0.2620,  0.8351],
          [ 0.9209,  2.0667,  0.4926,  1.3879,  1.5142],
          [-0.4524,  1.5968,  2.3990, -0.5796,  1.1033],
          [-0.4179,  0.5972,  1.9350,  0.2136,  0.7750],
          [ 0.0033, -0.5234,  0.4973,  1.2309,  0.6199]],

         [[-0.6326, -0.5939, -0.1359, -0.0636,  0.0743],
          [-0.1904, -0.5326, -0.9812, -1.1572, -0.3421],
          [-0.3705,  0.3074, -0.0607, -0.3725, -0.1068],
          [-0.5645, -0.7567, -0.9802, -0.3866,  0.2590],
          [ 0.0823, -0.0074, -0.2114,  0.2522,  0.8198]]]], device='cuda:0'),), 'grad_output': (tensor([[[[1.8389, 1.4754, 0.7583, 0.8973, 1.1330],
          [1.4409, 2.0260, 1.6827, 0.4276, 1.5788],
          [0.4251, 3.3738, 0.6538, 1.0533, 3.3021],
          [0.7232, 0.5160, 2.7773, 0.8202, 0.4945],
          [0.5850, 0.3772, 1.7576, 1.0424, 0.4658]],

         [[0.4863, 0.7108, 1.1429, 1.1681, 1.4493],
          [2.3139, 1.0769, 0.4998, 0.3559, 0.9162],
       

True

In [2]:
# Inspect the upstream gradient captured by the backward hook.
# NOTE(review): the saved output below is 3x3, but the setup cell uses h = w = 5 with
# padding=1, and the execution counts are out of order — this output looks stale;
# restart the kernel and re-run to confirm.
states["grad_output"][0]

tensor([[[[1.4013, 1.7406, 1.3197],
          [0.9091, 1.2141, 1.3131],
          [0.9499, 0.8774, 1.3109]],

         [[1.4815, 1.4291, 1.0360],
          [0.6996, 0.6797, 0.4451],
          [0.9602, 2.1373, 1.6364]]]], device='cuda:0')

In [4]:
# Inspect the input gradient captured by the backward hook.
# NOTE(review): the saved output below is 3x3 and mostly zeros, which does not match the
# 5x5 gradients printed by the setup cell — it looks like a leftover from an earlier
# session; restart the kernel and re-run to confirm.
states["grad_input"]

(tensor([[[[0.2053, 0.3073, 0.4549],
           [0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000]],
 
          [[0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000]]]], device='cuda:0'),)