In [1]:
import torch 
import time

# Record which PyTorch build produced the timings in this notebook.
print(torch.__version__)

# Base iteration counts read by test_convtranspose2d (divided by 10 there for c >= 256).
warmup_iter_raw, prof_iter_raw = 1000, 1000

def test_convtranspose2d(bs, c, hw, ks, stride, pad, outpad, dilation):
    warmup_iter = warmup_iter_raw
    prof_iter = prof_iter_raw
    
    if c >= 256: 
        warmup_iter //= 10 
        prof_iter //= 10 
        
    print(bs, c, hw, ks, stride, pad, outpad, dilation)
    x = torch.randn(bs, c, hw, hw, device='cuda', dtype=torch.half, requires_grad=True)  
    conv = torch.nn.ConvTranspose2d(
        in_channels=c, 
        out_channels=c, 
        kernel_size=ks, 
        stride=stride, 
        padding=pad, 
        output_padding=outpad, 
        groups=c,
        bias=False, 
        dilation=dilation
    ).half().cuda()

    y:torch.Tensor = conv(x)
    g = torch.ones_like(y)

    for warm_up in range(warmup_iter): 
        y = conv(x)
        y.backward(g)

    torch.cuda.synchronize()
    ts = time.time() 

    for it in range(prof_iter): 
        y = conv(x)

    torch.cuda.synchronize()
    te = time.time()

    t_forward = (te-ts)/prof_iter
    print(f'forward {t_forward: .3e}')

    ts = time.time()
    torch.cuda.synchronize() 

    for it in range(prof_iter): 
        y.backward(g, retain_graph=True)

    torch.cuda.synchronize()
    te = time.time()

    t_backward = (te-ts)/prof_iter
    print(f'backward {t_backward: .3e}')

    print(f'total {t_forward + t_backward: .3e}')
    
    print()

1.6.0a0+77b4e2d


In [2]:
def test():  
    print('bs c hw ks stride pad outpad dilation\n')
    # def test_convtranspose2d(bs, c, hw, ks, stride, pad, outpad, dilation)
    
    test_convtranspose2d(32, 128, 7, 1, 2, 0, 0, 2)
    test_convtranspose2d(32, 128, 7, 1, 1, 0, 0, 2)
    test_convtranspose2d(8, 128, 14, 1, 1, 0, 0, 3)
    test_convtranspose2d(1, 128, 7, 1, 3, 0, 1, 2)
    
    test_convtranspose2d(32, 512, 7, 1, 3, 0, 2, 2)
    test_convtranspose2d(8, 512, 14, 3, 2, 1, 1, 1)
    test_convtranspose2d(32, 256, 7, 1, 3, 0, 1, 1)
    test_convtranspose2d(1, 512, 14, 3, 3, 1, 1, 1)

In [15]:
# Benchmark run labeled 'master' (presumably a stock PyTorch build — confirm).
# NOTE(review): this cell's execution count (In [15]) is out of order relative
# to In [3]; the two labeled runs presumably come from separate kernels/builds.
print('master')
test()

master
bs c hw ks stride pad outpad dilation

32 128 7 1 2 0 0 2
forward  7.415e-05
backward  2.432e-03
total  2.506e-03

32 128 7 1 1 0 0 2
forward  7.270e-05
backward  3.238e-03
total  3.311e-03

8 128 14 1 1 0 0 3
forward  7.214e-05
backward  3.232e-03
total  3.304e-03

1 128 7 1 3 0 1 2
forward  7.211e-05
backward  3.110e-03
total  3.182e-03

32 512 7 1 3 0 2 2
forward  1.974e-01
backward  3.896e-01
total  5.870e-01

8 512 14 3 2 1 1 1
forward  8.204e-02
backward  1.598e-01
total  2.418e-01

32 256 7 1 3 0 1 1
forward  9.928e-02
backward  1.947e-01
total  2.940e-01

1 512 14 3 3 1 1 1
forward  3.833e-02
backward  8.089e-02
total  1.192e-01



In [3]:
# Benchmark run labeled 'without dilation check'. Presumably this is a PyTorch
# build with a dilation check removed from the ConvTranspose2d path — confirm.
# In the recorded output, the c >= 256 configurations are orders of magnitude
# faster than in the 'master' run (e.g. total 5.870e-01 -> 1.173e-02 s for
# 32 512 7 1 3 0 2 2), while the c=128 configurations are roughly unchanged.
print('without dilation check')
test()

without dilation check
bs c hw ks stride pad outpad dilation

32 128 7 1 2 0 0 2
forward  9.304e-05
backward  2.358e-03
total  2.451e-03

32 128 7 1 1 0 0 2
forward  7.569e-05
backward  3.104e-03
total  3.180e-03

8 128 14 1 1 0 0 3
forward  7.390e-05
backward  3.123e-03
total  3.197e-03

1 128 7 1 3 0 1 2
forward  7.711e-05
backward  2.982e-03
total  3.059e-03

32 512 7 1 3 0 2 2
forward  2.085e-04
backward  1.152e-02
total  1.173e-02

8 512 14 3 2 1 1 1
forward  1.887e-04
backward  1.424e-04
total  3.311e-04

32 256 7 1 3 0 1 1
forward  1.264e-04
backward  3.502e-03
total  3.629e-03

1 512 14 3 3 1 1 1
forward  7.291e-05
backward  1.455e-04
total  2.184e-04

