What is an element-wise square in PyTorch?

An element-wise operation goes over every element of a tensor and applies the same function to each one independently.

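Squaring is an example of a point-wise (element-wise) operation. Such operations are everywhere in ML, e.g. sin, cos, exp and activation functions. Softmax applies exp point-wise, but it is not purely point-wise because it also normalizes by a sum over a dimension.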

In [6]:
import torch

a = torch.tensor([1., 2., 3.])

# Three equivalent ways to square a tensor element-wise; each printed line
# should read tensor([1., 4., 9.]).
print(
    torch.square(a),  # torch.square
    a ** 2,           # power operator
    a * a,            # element-wise multiplication
    sep="\n",
)

def time_pytorch_function(func, input, *, warmup=5, iters=1):
    """Time ``func(input)`` on the GPU using CUDA events.

    CUDA kernel launches are asynchronous, so timing with Python's ``time``
    module would mostly capture launch overhead rather than kernel run time.
    Instead, CUDA events are recorded around the timed call(s). The elapsed
    time between them is read after synchronizing the device.

    Args:
        func: Callable to benchmark; it is called as ``func(input)``.
        input: Argument passed to ``func`` (for example a CUDA tensor).
        warmup: Number of untimed calls made before timing starts
            (default 5, matching the original behaviour).
        iters: Number of timed calls. The result is their average
            (default 1, i.e. a single timed call).

    Returns:
        Average elapsed time per timed call, in milliseconds.

    Raises:
        ValueError: If ``iters`` is less than 1.
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Untimed warmup calls, so one-time costs of the first calls are not measured.
    for _ in range(warmup):
        func(input)

    start.record()
    for _ in range(iters):
        func(input)
    end.record()
    # The end event must have completed before elapsed_time() can be queried.
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Large benchmark input: 10000 x 10000, ~400 MB assuming the default float32 dtype.
# Allocate it directly on the GPU. Calling `.cuda()` on a CPU tensor would first
# build it in host memory and then copy it to the device.
b = torch.randn(10000, 10000, device="cuda")

def square_2(a):
    """Square ``a`` element-wise by multiplying it with itself (``a * a``)."""
    product = a * a
    return product

def square_3(a):
    """Square ``a`` element-wise with the power operator (``a ** 2``)."""
    exponent = 2
    return a ** exponent

tensor([1., 4., 9.])
tensor([1., 4., 9.])
tensor([1., 4., 9.])


Compare `torch.square()` with `a * a` and `a ** 2`.

When profiling CUDA code:

**CUDA is asynchronous.**

If you time with Python's `time` module, you only measure the overhead of launching a kernel, not how long the kernel actually runs. Use CUDA events (`torch.cuda.Event`) together with `torch.cuda.synchronize()` instead.

In [7]:
# CUDA-event timings in milliseconds. Jupyter only displays a cell's final
# expression, so print these explicitly instead of silently discarding them.
print(f"torch.square: {time_pytorch_function(torch.square, b):.3f} ms")
print(f"a * a:        {time_pytorch_function(square_2, b):.3f} ms")
print(f"a ** 2:       {time_pytorch_function(square_3, b):.3f} ms")

# Now profile each function using the PyTorch profiler. The first profiled
# region may also include one-time profiler setup cost (e.g. an
# "Activity Buffer Request" row), so compare later tables with that in mind.
for label, fn in (
    ("torch.square", torch.square),
    ("a * a", square_2),
    ("a ** 2", square_3),
):
    print("=============")
    print(f"Profiling {label}")
    print("=============")
    with torch.profiler.profile() as prof:
        fn(b)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

Profiling torch.square
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
               aten::square         0.15%       6.100us        40.98%       1.724ms       1.724ms             1  
                  aten::pow         1.92%      80.600us        40.83%       1.718ms       1.718ms             1  
          aten::result_type         0.04%       1.600us         0.04%       1.600us       1.600us             1  
                   aten::to         0.03%       1.300us         0.03%       1.300us       1.300us             1  
    Activity Buffer Request        21.57%     907.300us        21.57%     907.300us     907.300us             1  
           cudaLaunchKernel        17.28%     726.700us        17