In [1]:
import torch
from torch.profiler import profile, record_function, ProfilerActivity

In [2]:
def kde(samples, locations, bandwidth, verbose=True):
    """Gaussian kernel density estimate of `samples` evaluated at `locations`.

    For every batch element, an unnormalized Gaussian kernel
    ``exp(-d**2 / (2 * bandwidth**2))`` is evaluated between each sample and
    each location (``d`` is the Euclidean distance along the feature axis),
    summed over the samples, and finally normalized so that the values over
    all locations sum to 1 for each batch element.

    Args:
        samples: tensor of shape ``(*batch, n_samples, n_features)``.
        locations: tensor of shape ``(*grid, n_features)`` holding the points
            at which the density is evaluated (e.g. a stacked meshgrid).
        bandwidth: kernel width; a Python number or a tensor that broadcasts
            against the distance tensor.
        verbose: when True (the default) the shapes of the intermediate
            tensors are printed, which is handy while profiling memory use.

    Returns:
        Tensor of shape ``(*batch, *grid)``. If every kernel value of a batch
        element underflows to zero the normalization divides 0 by 0 and that
        element is NaN.

    Note:
        The intermediate difference tensor has shape
        ``(*batch, n_samples, n_features, *grid)``, so memory use grows with
        the product of all of those sizes.
    """
    log = print if verbose else (lambda *_args, **_kwargs: None)

    # Number of grid axes in `locations` (everything except the feature axis).
    n_grid_dims = len(locations.shape[:-1])

    log(f'samples shape: {samples.shape}')
    log(f'locations shape: {locations.shape}')
    log(f'bandwidth: {bandwidth}')

    # Append one singleton axis per grid axis so samples broadcast against
    # the grid: (*batch, n_samples, n_features, 1, ..., 1).
    all_samples = samples.reshape(samples.shape + (1,) * n_grid_dims)
    log(f'all_samples shape: {all_samples.shape}')

    # Feature axis of `locations` goes first -> (n_features, *grid), which
    # lines up with the trailing axes of `all_samples`.
    samples_copies = all_samples - torch.movedim(locations, -1, 0)
    log(f'samples_copies shape: {samples_copies.shape}')

    # Euclidean distance: reduce the feature axis, which sits just before
    # the grid axes.
    diff = torch.norm(
        samples_copies,
        dim=-n_grid_dims - 1,
    )
    log(f'diff shape: {diff.shape}')

    # Sum the kernel over the sample axis; in `diff` that axis comes right
    # after the batch axes, i.e. at index len(samples.shape) - 2.
    out = (-diff ** 2 / (2.0 * bandwidth ** 2)).exp().sum(
        dim=len(samples.shape) - 2)
    log(f'out shape: {out.shape}')

    # Normalize over the grid axes only. The leading axes of `out` are the
    # batch axes of `samples`, so the reduction must be expressed relative to
    # the trailing grid axes rather than derived from `locations`' rank
    # (which only coincides for a 2-D grid with a single batch axis).
    grid_axes = tuple(range(-n_grid_dims, 0))
    # An empty `dim` tuple would reduce over every axis, so skip the
    # reduction when there are no grid axes at all.
    norm = out.sum(dim=grid_axes) if grid_axes else out
    log(f'norm shape: {norm.shape}')

    result = out / norm.reshape(norm.shape + (1,) * n_grid_dims)
    log(f'result shape: {result.shape}')
    return result


In [5]:
# --- Benchmark inputs -------------------------------------------------------
# A batch of 20 sets of 1000 two-dimensional samples. The values are constant
# ones; presumably only the tensor shapes matter for this memory profile.
samples = torch.ones(20, 1000, 2)

# 200 evenly spaced bin positions spanning -0.03 .. 0.03, reused for both axes
# of a square evaluation grid.
bins = 1e-3 * torch.linspace(-30, 30, 200)

# Kernel width: half the spacing between neighbouring bins.
bandwidth = (bins[1] - bins[0]) / 2

# Stack the two meshgrid coordinate planes into one (200, 200, 2) tensor of
# evaluation points.
xx = torch.meshgrid(bins, bins, indexing='ij')
locations = torch.stack(xx, dim=-1)

# --- Profile one KDE evaluation (CPU time and memory) -----------------------
with profile(
    activities=[ProfilerActivity.CPU],
    profile_memory=True,
) as prof:
    hist = kde(samples, locations, bandwidth)

# Save a trace for chrome://tracing, then list the ten operators that
# allocate the most memory themselves.
prof.export_chrome_trace("kde_rectangle.json")

memory_table = prof.key_averages().table(
    sort_by="self_cpu_memory_usage",
    row_limit=10,
)
print(memory_table)

samples shape: torch.Size([20, 1000, 2])
locations shape: torch.Size([200, 200, 2])
bandwidth: 0.5
all_samples shape: torch.Size([20, 1000, 2, 1, 1])


STAGE:2023-07-02 12:16:35 138458:138458 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


samples_copies shape: torch.Size([20, 1000, 2, 200, 200])
diff shape: torch.Size([20, 1000, 200, 200])
out shape: torch.Size([20, 200, 200])
norm shape: torch.Size([20])
result shape: torch.Size([20, 200, 200])
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   aten::sub        29.50%     401.357ms        29.50%     401.357ms     401.357ms       5.96 Gb       5.96 Gb             1  
                   aten::div         9.54%     129.718ms         9.54%     129.736ms      64.868ms       2.98 Gb       2.98 Gb             2  
    aten::linalg_vector_norm        29.03%     394.917ms        29.03%    

STAGE:2023-07-02 12:16:37 138458:138458 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-07-02 12:16:37 138458:138458 ActivityProfilerController.cpp:321] Completed Stage: Post Processing
