
Profiler does not record CUDA times #124547

Closed
mmeendez8 opened this issue Apr 20, 2024 · 8 comments
Labels
oncall: profiler profiler-related issues (cpu, gpu, kineto)

Comments


mmeendez8 commented Apr 20, 2024

🐛 Describe the bug

import torch
from torch.profiler import profile, record_function, ProfilerActivity, schedule
from torch import Tensor

def my_normalize(input: Tensor, mean: Tensor, std: Tensor):
    mean = mean.view(-1, 1, 1)
    std = std.view(-1, 1, 1)
    return (input - mean) / std

image = torch.randn(1, 3, 224, 224)
mean = torch.tensor([123.675, 116.28, 103.53])
std = torch.tensor([58.395, 57.12, 57.375])

with torch.profiler.profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1,  warmup=9, active=90, repeat=1),
    record_shapes=True,
) as prof:
    for i in range(100):
        my_normalize(image, mean, std)
        prof.step() 

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

I got the following result:

-----------------  ------------  ------------  ------------  ------------  ------------  ------------  
             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------  ------------  ------------  ------------  ------------  ------------  ------------  
    ProfilerStep*        41.98%       8.486ms       100.00%      20.212ms     224.578us            90  
        aten::div        33.01%       6.671ms        33.01%       6.671ms      74.122us            90  
        aten::sub        20.93%       4.230ms        20.93%       4.230ms      47.000us            90  
       aten::view         4.08%     825.000us         4.08%     825.000us       4.583us           180  
-----------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 20.212ms

STAGE:2024-04-20 10:54:39 32358:32358 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-04-20 10:54:39 32358:32358 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-04-20 10:54:39 32358:32358 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
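For reference, the schedule above records only 90 of the 100 loop iterations, which is why every row reports 90 calls. A minimal sketch of how `schedule(wait=1, warmup=9, active=90, repeat=1)` maps step indices to phases within its single cycle (a simplified model for illustration, not the actual `torch.profiler` implementation):

```python
def schedule_phase(step, wait=1, warmup=9, active=90):
    # Simplified single-cycle model of torch.profiler.schedule:
    # skip `wait` steps, warm up for `warmup` steps, then record `active` steps.
    if step < wait:
        return "wait"
    if step < wait + warmup:
        return "warmup"
    if step < wait + warmup + active:
        return "record"
    return "none"

phases = [schedule_phase(i) for i in range(100)]
print(phases.count("record"))  # 90, matching the "# of Calls" column
```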

Versions

Collecting environment information...
PyTorch version: 2.2.2+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.31

Python version: 3.11.3 (main, May  3 2023, 11:11:08) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.146.1-microsoft-standard-WSL2-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070 Laptop GPU
Nvidia driver version: 531.79
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      39 bits physical, 48 bits virtual
CPU(s):                             12
On-line CPU(s) list:                0-11
Thread(s) per core:                 2
Core(s) per socket:                 6
Socket(s):                          1
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              165
Model name:                         Intel(R) Core(TM) i7-10750H CPU @ 2.60GHz
Stepping:                           2
CPU MHz:                            2591.990
BogoMIPS:                           5183.98
Virtualization:                     VT-x
Hypervisor vendor:                  Microsoft
Virtualization type:                full
L1d cache:                          192 KiB
L1i cache:                          192 KiB
L2 cache:                           1.5 MiB
L3 cache:                           12 MiB
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        KVM: Mitigation: VMX disabled
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Unknown: Dependent on hypervisor status
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt xsaveopt xsavec xgetbv1 xsaves md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] torch==2.2.2+cu118
[pip3] torchaudio==2.2.2+cu118
[pip3] torchvision==0.17.2+cu118
[pip3] triton==2.2.0
[conda] Could not collect

cc @robieta @chaekit @aaronenyeshi @guotuofeng @guyang3532 @dzhulgakov @davidberard98 @briancoutinho @sraikund16 @sanrise

@colesbury colesbury added the oncall: profiler profiler-related issues (cpu, gpu, kineto) label Apr 22, 2024
@sraikund16
Contributor

@mmeendez8 This is because none of your tensors are allocated on the device. If you use:

device = torch.device("cuda")
image = torch.randn(1, 3, 224, 224).to(device)
mean = torch.tensor([123.675, 116.28, 103.53]).to(device)
std = torch.tensor([58.395, 57.12, 57.375]).to(device)

You will see the following chart:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*        35.39%      27.286ms        99.96%      77.064ms     856.261us       0.000us         0.00%     626.907us       6.966us            90  
                                              aten::sub        27.41%      21.128ms        30.83%      23.767ms     264.073us     310.845us         1.22%     310.845us       3.454us            90  
                                              aten::div        24.19%      18.646ms        26.81%      20.668ms     229.642us     316.062us         1.24%     316.062us       3.512us            90  
                                             aten::view         6.93%       5.343ms         6.93%       5.343ms      29.686us       0.000us         0.00%       0.000us       0.000us           180  
                                       cudaLaunchKernel         6.04%       4.660ms         6.04%       4.660ms      25.889us       0.000us         0.00%       0.000us       0.000us           180  
                                  cudaDeviceSynchronize         0.04%      28.970us         0.04%      28.970us      28.970us       0.000us         0.00%       0.000us       0.000us             1  
                                          ProfilerStep*         0.00%       0.000us         0.00%       0.000us       0.000us      24.909ms        97.54%      24.909ms     276.761us            90  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     310.845us         1.22%     310.845us       3.454us            90  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     316.062us         1.24%     316.062us       3.512us            90  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 77.092ms
Self CUDA time total: 25.535ms
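One way to catch this earlier is to check the inputs' devices before profiling, so empty CUDA columns aren't just tensors left on the CPU. `off_device` below is a hypothetical helper written for this sketch, not a PyTorch API:

```python
import torch

# Hypothetical helper (not a PyTorch API): return the names of tensors that
# are NOT on the expected device type.
def off_device(expected: str, **tensors):
    return [name for name, t in tensors.items() if t.device.type != expected]

image = torch.randn(1, 3, 224, 224)       # created on the CPU by default
print(off_device("cuda", image=image))    # -> ['image'], still on the CPU
```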

@sraikund16
Contributor

Closing because this is just a user error


mmeendez8 commented May 10, 2024

I am sorry @sraikund16, I copied the CPU example from my code instead of the CUDA one.

This is what I was trying:

import torch
from torch.profiler import ProfilerActivity, schedule
from torch import Tensor

def my_normalize(input: Tensor, mean: Tensor, std: Tensor):
    mean = mean.view(-1, 1, 1)
    std = std.view(-1, 1, 1)
    return (input - mean) / std

device = torch.device("cuda")

image = torch.randn(1, 3, 224, 224)
mean = torch.tensor([123.675, 116.28, 103.53])
std = torch.tensor([58.395, 57.12, 57.375])

image_cuda = image.to(device)
mean_cuda = mean.to(device)
std_cuda = std.to(device)

with torch.profiler.profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1,  warmup=9, active=90, repeat=1),
    record_shapes=True,
) as prof:    
    for i in range(1000):
        r = my_normalize(image_cuda, mean_cuda, std_cuda)
        prof.step() 

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

And this is the output I get:

-----------------  ------------  ------------  ------------  ------------  ------------  ------------  
             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------  ------------  ------------  ------------  ------------  ------------  ------------  
    ProfilerStep*        50.64%       5.776ms       100.00%      11.407ms     126.744us            90  
        aten::sub        27.22%       3.105ms        27.22%       3.105ms      34.500us            90  
        aten::div        17.99%       2.052ms        17.99%       2.052ms      22.800us            90  
       aten::view         4.16%     474.000us         4.16%     474.000us       2.633us           180  
-----------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 11.407ms

STAGE:2024-05-10 08:57:12 30199:30199 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-05-10 08:57:13 30199:30199 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-10 08:57:13 30199:30199 ActivityProfilerController.cpp:324] Completed Stage: Post Processing

Also, if I change to the experimental profiler config (as I saw mentioned in a couple of issues) and run this code:

import torch
from torch.profiler import profile, ProfilerActivity, schedule
from torch import Tensor

def my_normalize(input: Tensor, mean: Tensor, std: Tensor):
    mean = mean.view(-1, 1, 1)
    std = std.view(-1, 1, 1)
    return (input - mean) / std

device = torch.device("cuda")

image = torch.randn(1, 3, 224, 224)
mean = torch.tensor([123.675, 116.28, 103.53])
std = torch.tensor([58.395, 57.12, 57.375])

image_cuda = image.to(device)
mean_cuda = mean.to(device)
std_cuda = std.to(device)

with profile(with_stack=True, 
             profile_memory=True, 
             experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True), 
             schedule=schedule(wait=1,  warmup=9, active=90, repeat=1)) as prof:   
    for i in range(1000):
        r = my_normalize(image_cuda, mean_cuda, std_cuda)
        prof.step() 

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

I get a different output that includes CUDA memory usage but not CUDA times:

STAGE:2024-05-10 09:01:10 30199:30199 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-05-10 09:01:10 30199:30199 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-10 09:01:10 30199:30199 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
-----------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-----------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
    ProfilerStep*        55.61%       5.844ms       100.00%      10.509ms     116.767us           0 b           0 b           0 b    -103.36 Mb            90  
        aten::sub        25.72%       2.703ms        25.72%       2.703ms      30.033us           0 b           0 b      51.68 Mb      51.68 Mb            90  
        aten::div        15.54%       1.633ms        15.54%       1.633ms      18.144us           0 b           0 b      51.68 Mb      51.68 Mb            90  
       aten::view         3.13%     329.000us         3.13%     329.000us       1.828us           0 b           0 b           0 b           0 b           180  
-----------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 10.509ms
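The memory column is at least self-consistent: each of the 90 recorded `sub`/`div` calls allocates one float32 output of shape (1, 3, 224, 224), which adds up to exactly the 51.68 Mb shown per op, and both outputs being freed gives the -103.36 Mb on `ProfilerStep*` (a quick back-of-the-envelope check, assuming "Mb" in the table means MiB):

```python
bytes_per_call = 1 * 3 * 224 * 224 * 4          # one float32 output tensor
per_op_mib = 90 * bytes_per_call / (1024 ** 2)  # 90 recorded calls per op
print(round(per_op_mib, 2))      # 51.68, matching aten::sub and aten::div
print(round(2 * per_op_mib, 2))  # 103.36, the net amount freed in ProfilerStep*
```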

@lu-renjie

I have the same problem as you, have you solved this problem?

@mmeendez8
Author

Not at all @lu-renjie


sraikund16 commented Jun 10, 2024

@mmeendez8 Sorry for missing this. I ran the following block of code:

import torch
from torch.profiler import ProfilerActivity, schedule
from torch import Tensor

def my_normalize(input: Tensor, mean: Tensor, std: Tensor):
    mean = mean.view(-1, 1, 1)
    std = std.view(-1, 1, 1)
    return (input - mean) / std

device = torch.device("cuda")
image = torch.randn(1, 3, 224, 224)
mean = torch.tensor([123.675, 116.28, 103.53])
std = torch.tensor([58.395, 57.12, 57.375])
image_cuda = image.to(device)
mean_cuda = mean.to(device)
std_cuda = std.to(device)

with torch.profiler.profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1,  warmup=9, active=90, repeat=1),
    record_shapes=True,
) as prof:
    for i in range(1000):
        r = my_normalize(image_cuda, mean_cuda, std_cuda)
        prof.step()

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

and this was the result:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*        48.96%      26.057ms        99.84%      53.133ms     590.362us       0.000us         0.00%     781.047us       8.678us            90  
                                              aten::sub        19.42%      10.334ms        23.07%      12.276ms     136.401us     386.268us         2.71%     386.268us       4.292us            90  
                                              aten::div        15.76%       8.385ms        18.45%       9.821ms     109.125us     394.779us         2.77%     394.779us       4.386us            90  
                                             aten::view         9.35%       4.978ms         9.35%       4.978ms      27.658us       0.000us         0.00%       0.000us       0.000us           180  
                                       cudaLaunchKernel         6.35%       3.379ms         6.35%       3.379ms      18.772us       0.000us         0.00%       0.000us       0.000us           180  
                                  cudaDeviceSynchronize         0.16%      86.414us         0.16%      86.414us      86.414us       0.000us         0.00%       0.000us       0.000us             1  
                                          ProfilerStep*         0.00%       0.000us         0.00%       0.000us       0.000us      13.458ms        94.51%      13.458ms     149.531us            90  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     386.268us         2.71%     386.268us       4.292us            90  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     394.779us         2.77%     394.779us       4.386us            90  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 53.219ms
Self CUDA time total: 14.239ms
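For what it's worth, the Self CUDA % column here is just each row's self CUDA time divided by the 14.239 ms total, e.g. for `aten::sub`:

```python
self_cuda_us = 386.268   # aten::sub, Self CUDA column (in microseconds)
total_cuda_us = 14239.0  # Self CUDA time total: 14.239 ms
print(round(100 * self_cuda_us / total_cuda_us, 2))  # 2.71, as in the table
```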


lu-renjie commented Jun 10, 2024 via email

@sraikund16 sraikund16 reopened this Jun 10, 2024
@sraikund16
Contributor

@mmeendez8 On second thought, it sounds like you may not have Kineto installed, since the same code block works for me. Please check here: https://github.com/pytorch/kineto
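A quick way to check this from Python is to ask the profiler which activities it supports; `torch.profiler.supported_activities()` only includes CUDA when the build's Kineto has CUPTI support, even if `torch.cuda.is_available()` is True (assuming a standard PyTorch build):

```python
import torch
from torch.profiler import ProfilerActivity, supported_activities

# If this prints False, Kineto cannot collect CUDA events and the CUDA
# columns in the profiler table will stay empty.
print(ProfilerActivity.CUDA in supported_activities())
```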
