In [1]:
import torch
import torch.nn as nn
from time import time

In [32]:
# Preparing experiment variables
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = torch.float32
torch.manual_seed(0)  # make the random test map reproducible across runs

MAP_SIZE = 1000       # side length of the square test map
CONDUCTANCE = 2.0     # uniform conductance value used by both update variants
# NOTE(review): each update adds CONDUCTANCE * sum(neighbour - centre). For a
# checkerboard pattern that multiplies the amplitude by (1 - 8 * CONDUCTANCE),
# i.e. -15 here, so repeated benchmark calls grow the map values without bound
# (eventually inf/NaN). Use CONDUCTANCE <= 0.25 if the values themselves matter.

# 3x3 difference stencils. nn.Conv2d computes cross-correlation without
# padding, so weight[r][c] multiplies input[i + r - 1, j + c - 1] and each
# stencil yields (neighbour - centre) for one neighbour direction.
STENCIL_LEFT = [[0, 0, 0], [1, -1, 0], [0, 0, 0]]
STENCIL_RIGHT = [[0, 0, 0], [0, -1, 1], [0, 0, 0]]
STENCIL_UP = [[0, 1, 0], [0, -1, 0], [0, 0, 0]]
STENCIL_DOWN = [[0, 0, 0], [0, -1, 0], [0, 1, 0]]


def make_fixed_conv(stencils):
    """Build a bias-free 3x3 Conv2d whose weights are the given stencils.

    Args:
        stencils: sequence of 3x3 nested lists; each one becomes one output
            channel (the conv has a single input channel).

    Returns:
        (conv, weight): the module, created on `device` with dtype DTYPE, and
        the (len(stencils), 1, 3, 3) weight tensor that was copied into it.
    """
    weight = torch.tensor([[s] for s in stencils], device=device, dtype=DTYPE)
    conv = nn.Conv2d(1, len(stencils), kernel_size=(3, 3), bias=False,
                     device=device, dtype=DTYPE)
    with torch.no_grad():
        conv.weight[:] = weight
    return conv, weight


# conv0 fuses all four directions into one 4-channel conv (used by the "merge"
# functions); conv1..conv4 hold one direction each (used by the "split" ones).
conv0, kernel0 = make_fixed_conv([STENCIL_LEFT, STENCIL_RIGHT, STENCIL_UP, STENCIL_DOWN])
conv1, kernel1 = make_fixed_conv([STENCIL_DOWN])
conv2, kernel2 = make_fixed_conv([STENCIL_LEFT])
conv3, kernel3 = make_fixed_conv([STENCIL_RIGHT])
conv4, kernel4 = make_fixed_conv([STENCIL_UP])

# Both variants start from the same random map but update separate copies.
# Everything is created on `device` (not a hard-coded 'cuda') so the notebook
# also runs on CPU-only machines.
map_split = torch.rand((1, 1, MAP_SIZE, MAP_SIZE), device=device, dtype=DTYPE)
map_merge = map_split.clone()
# An unpadded 3x3 conv output is two pixels smaller in each spatial dimension.
INNER = MAP_SIZE - 2
map_conduct_merge = torch.ones((1, 4, INNER, INNER), device=device, dtype=DTYPE) * CONDUCTANCE
map_conduct = torch.ones((1, 1, INNER, INNER), device=device, dtype=DTYPE) * CONDUCTANCE



In [65]:
def merge_kernel_update():
    """Apply one in-place update step to the global `map_merge` via the fused conv.

    `conv0` produces one neighbour difference per output channel. Those
    differences are weighted element-wise by `map_conduct_merge`, reduced over
    the channel axis and added to the interior of `map_merge`, i.e. everything
    except the one-pixel border that an unpadded 3x3 conv cannot produce.
    """
    with torch.inference_mode():
        weighted = map_conduct_merge * conv0(map_merge)
        interior = map_merge[:, :, 1:-1, 1:-1]  # view: writes land in map_merge
        interior += weighted.sum(dim=1, keepdim=True)
        
def merge_kernel_update2():
    """Same update as `merge_kernel_update`, but with the channel sum unrolled.

    Weights the four neighbour differences from `conv0` by `map_conduct_merge`,
    adds the four channels with explicit `+` (instead of `torch.sum`) and adds
    the result in place to the interior of the global `map_merge`.
    """
    with torch.inference_mode():
        diff_map = conv0(map_merge) * map_conduct_merge

        # Slicing with k:k+1 keeps the channel axis, so `diff` has a single
        # channel and lines up with the interior view of map_merge for any
        # batch size. The old `[:, k, :, :]` indexing dropped that axis, and
        # only worked for batch size 1 through accidental broadcasting.
        diff = (diff_map[:, 0:1] + diff_map[:, 1:2]
                + diff_map[:, 2:3] + diff_map[:, 3:4])

        map_merge[:, :, 1:-1, 1:-1] += diff


def split_kernel_update():
    """Apply one in-place update step to the global `map_split` using four convs.

    conv1..conv4 are each evaluated on `map_split`. Their outputs are summed
    left to right and the total is added to the interior of `map_split`, i.e.
    everything except the one-pixel border that the unpadded 3x3 convs cannot
    produce.
    """
    with torch.inference_mode():
        first, second, third, fourth = (c(map_split) for c in (conv1, conv2, conv3, conv4))
        total = first + second
        total = total + third
        total = total + fourth
        map_split[:, :, 1:-1, 1:-1] += total

@torch.compile
def merge_kernel_update_compile():
    """torch.compile-d twin of `merge_kernel_update`: one in-place update of `map_merge`.

    Uses the same per-direction conductance map (`map_conduct_merge`, one
    channel per conv0 output channel) as the eager version, so the compiled and
    eager benchmarks perform the same computation. The previous body used the
    single-channel `map_conduct`, which broadcasts one conductance across all
    directions and does different memory traffic.
    """
    with torch.inference_mode():
        diff_map = conv0(map_merge)
        # Weight each neighbour difference by its own conductance, then reduce
        # over the channel axis. keepdim keeps a single channel so the result
        # matches the interior view of map_merge.
        diff = torch.sum(diff_map * map_conduct_merge, dim=1, keepdim=True)
        map_merge[:, :, 1:-1, 1:-1] += diff

@torch.compile
def split_kernel_update_compile():
    """torch.compile-d twin of `split_kernel_update`.

    Sums the outputs of conv1..conv4 on `map_split` (left to right) and adds
    the total in place to the interior of `map_split`.
    """
    with torch.inference_mode():
        acc = conv1(map_split) + conv2(map_split)
        acc = acc + conv3(map_split)
        acc = acc + conv4(map_split)
        map_split[:, :, 1:-1, 1:-1] += acc
        
def time_measure(func, num_run, warmup=0):
    """Print and return the mean time per call of `func()`, in microseconds.

    Args:
        func: zero-argument callable to benchmark.
        num_run: number of timed calls; must be positive.
        warmup: number of untimed calls made first (e.g. so torch.compile
            compilation is not counted). Defaults to 0, the original behaviour.

    Returns:
        Mean time per call in microseconds (also printed).

    Raises:
        ValueError: if num_run is not positive.
    """
    from time import perf_counter  # monotonic, high-resolution timer

    if num_run <= 0:
        raise ValueError(f"num_run must be positive, got {num_run}")

    def _sync():
        # CUDA kernels are launched asynchronously; without a barrier the
        # clock would stop before the queued GPU work has actually finished.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(warmup):
        func()
    _sync()
    start = perf_counter()
    for _ in range(num_run):
        func()
    _sync()
    elapsed = perf_counter() - start
    avg = elapsed * 1_000_000 / num_run
    print(avg)
    return avg



In [71]:
%%timeit -r 100 -n 1000
# Benchmark: eager split variant (four separate single-channel convs per step).
# NOTE(review): CUDA ops are launched asynchronously and nothing here
# synchronizes, so on GPU this is closer to queue throughput than per-call
# latency; add torch.cuda.synchronize() to the timed body to measure the latter.
split_kernel_update()

196 μs ± 22.4 μs per loop (mean ± std. dev. of 100 runs, 1,000 loops each)


In [72]:
%%timeit -r 100 -n 1000
# Benchmark: torch.compile-d split variant.
# NOTE(review): if this is the first call, torch.compile compilation time lands
# in the first repeat; call the function once beforehand. On GPU the same
# asynchronous-launch caveat applies (no torch.cuda.synchronize() in the body).
split_kernel_update_compile()

158 μs ± 8.8 μs per loop (mean ± std. dev. of 100 runs, 1,000 loops each)


In [73]:
%%timeit  -r 100 -n 1000
# Benchmark: eager merge variant (one fused 4-channel conv + channel sum).
# NOTE(review): CUDA ops are launched asynchronously and nothing here
# synchronizes, so on GPU this is closer to queue throughput than per-call
# latency; add torch.cuda.synchronize() to the timed body to measure the latter.
merge_kernel_update()

546 μs ± 14.7 μs per loop (mean ± std. dev. of 100 runs, 1,000 loops each)


In [74]:
%%timeit  -r 100 -n 1000
# Benchmark: torch.compile-d merge variant.
# NOTE(review): if this is the first call, torch.compile compilation time lands
# in the first repeat; call the function once beforehand. On GPU the same
# asynchronous-launch caveat applies (no torch.cuda.synchronize() in the body).
merge_kernel_update_compile()

307 μs ± 18.3 μs per loop (mean ± std. dev. of 100 runs, 1,000 loops each)
