In [None]:
from time import perf_counter

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn.functional import interpolate


In [None]:
def time_sync():
    """Return a high-resolution host timestamp (seconds) taken after pending GPU work drains.

    CUDA kernels are launched asynchronously, so reading the host clock right after a
    launch would measure only the launch overhead. When CUDA is available,
    ``torch.cuda.synchronize()`` blocks until the queued work on the current device has
    finished before the timestamp is read. On machines without CUDA the synchronization
    is skipped, so the helper still works (and just behaves like ``perf_counter``).

    Returns:
        float: value of ``time.perf_counter()``. Only differences between two calls
        are meaningful.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return perf_counter()

In [None]:
## Measure time using time_sync with time.perf_counter
#
# time_sync() synchronizes the GPU before reading time.perf_counter(), so each
# t1 - t0 interval covers the queued interpolate work finishing, not just its
# asynchronous launch. The interval also includes host-side launch/sync overhead.

Npix = 128
batch_size = 32
scale_factor = (1,3,3)  # applied to the (D, H, W) dims of the 5-D input below
Niter = 10000
Nwarmup = 10            # untimed calls, excluded from the statistics

# 5-D input is interpreted by interpolate as (N, C, D, H, W); batch_size sits in the depth axis
input_shape = (1,1,batch_size,Npix,Npix)

# Warm-up: the first calls typically include one-off setup cost (kernel loading,
# caching-allocator growth) that would otherwise appear as outliers in max/mean/std
# and in the histogram.
for _ in range(Nwarmup):
    interpolate(torch.rand(size=input_shape, dtype=torch.float32, device='cuda'),
                scale_factor=scale_factor, mode='area')

times = []  # seconds per interpolate call
for _ in range(Niter):
    # fresh input each iteration; its generation happens before t0 and is not timed
    measurements = torch.rand(size=input_shape, dtype=torch.float32, device='cuda')

    t0 = time_sync()
    measurements_up = interpolate(measurements, scale_factor=scale_factor, mode='area')
    t1 = time_sync()
    times.append(t1 - t0)

print(f"measurements.shape = {measurements.shape}")
print(f"measurements_up.shape = {measurements_up.shape}")
print(f"min = {np.min(times):.3g}, mean = {np.mean(times):.3g}, max = {np.max(times):.3g}, std = {np.std(times):.3g}")

fig, ax = plt.subplots()
ax.hist(times, bins=20)
ax.set(title=f"interpolate(mode='area'), Npix={Npix}, timed with time_sync",
       xlabel="time per call [s]", ylabel="count")
plt.show()

In [None]:
## Measure time using torch.cuda.Event
#
# Events are recorded into the CUDA stream, and the elapsed time between them is
# measured on the GPU. Unlike time_sync, the result therefore does not include the
# host-side cost of launching the kernel and returning from torch.cuda.synchronize().
# NOTE(review): Npix here (1024) differs from the time_sync cell (128), so the two
# histograms measure different workloads — align them if the goal is to compare methods.

Npix = 1024
batch_size = 32
scale_factor = (1,3,3)  # applied to the (D, H, W) dims of the 5-D input below
Niter = 10000
Nwarmup = 10            # untimed calls, excluded from the statistics

# 5-D input is interpreted by interpolate as (N, C, D, H, W); batch_size sits in the depth axis
input_shape = (1,1,batch_size,Npix,Npix)

# Events can be re-recorded (each record() overwrites the previous timestamp),
# so create them once instead of on every iteration.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# Warm-up: the first calls typically include one-off setup cost (kernel loading,
# caching-allocator growth) that would otherwise appear as outliers.
for _ in range(Nwarmup):
    interpolate(torch.rand(size=input_shape, dtype=torch.float32, device='cuda'),
                scale_factor=scale_factor, mode='area')

times = []  # seconds per interpolate call
for _ in range(Niter):
    measurements = torch.rand(size=input_shape, dtype=torch.float32, device='cuda')

    # start_event is enqueued after torch.rand on the same (default) stream,
    # so the events bracket only the interpolate call.
    start_event.record()
    measurements_up = interpolate(measurements, scale_factor=scale_factor, mode='area')
    end_event.record()

    # elapsed_time() needs both events to have completed on the device
    torch.cuda.synchronize()
    times.append(start_event.elapsed_time(end_event)/1000)  # elapsed_time is in ms -> seconds

print(f"measurements.shape = {measurements.shape}")
print(f"measurements_up.shape = {measurements_up.shape}")
print(f"min = {np.min(times):.3g}, mean = {np.mean(times):.3g}, max = {np.max(times):.3g}, std = {np.std(times):.3g}")

fig, ax = plt.subplots()
ax.hist(times, bins=20)
ax.set(title=f"interpolate(mode='area'), Npix={Npix}, timed with torch.cuda.Event",
       xlabel="time per call [s]", ylabel="count")
plt.show()