In [27]:
import torch
from torch.profiler import profile, record_function, ProfilerActivity

# Check CUDA availability
assert torch.cuda.is_available(), "CUDA device not found!"

# Create basic tensors with profiling
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    with record_function("Tensor Creation"):
        # Explicitly create tensors on GPU
        a = torch.tensor([2, 4, 5], device='cuda')
        e = torch.tensor([[1, 2, 3], [4, 5, 6]], device='cuda')
        f = torch.tensor([7, 8, 9], device='cuda')
        g = torch.randn(2, 3, device='cuda')
        h = torch.randn(3, 2, device='cuda')
        x = torch.rand(3, 4, device='cuda')
    
    with record_function("Element-wise Operations"):
        result = e * f  # Now on GPU
    
    with record_function("Matrix Multiplication"):
        product = g @ h  # GPU matrix multiplication
    
    with record_function("Property Checks"):
        shape = x.shape
        dtype = x.dtype
        device = x.device

# Print profiling results
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=10
))

# Export trace
prof.export_chrome_trace("gpu_tensor_ops_trace.json")


---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
      Matrix Multiplication         0.08%     405.600us        67.85%     328.666ms     328.666ms     352.000us         0.08%     328.690ms     328.690ms         -16 b         -16 b       8.13 Mb           0 b             1  
               aten::matmul         0.02%      89.900us        67.76%     328.260ms     328.26