
Eager CUDAGraph + stream performance #161104

Description

@BoyuanFeng

In the repro below, with the warmup loops in Block1 and Block2 both present, the latency of adding two CUDA tensors (x + y) is 0.0013 ms. However, commenting out either Block1 or Block2 reduces the latency to 0.0011 ms. This is unexpected, since there is a torch.cuda.synchronize() after each of Block1 and Block2, so all warmup work should be fully drained before the graph is captured and timed.

Repro:

import torch


x = torch.randn(1024, device="cuda", requires_grad=False)
y = torch.randn(1024, device="cuda", requires_grad=False)


def fn():
    return x + y


# warmup
############################# <---- Block1: comment out this warmup loop to reduce latency from 0.0013 ms to 0.0011 ms
for _ in range(100):
    fn()
############################

torch.cuda.synchronize()

with torch.cuda.stream(torch.cuda.Stream()):
    # warmup
    ############################# <---- Block2: comment out this warmup loop (or Block1 above) to reduce latency from 0.0013 ms to 0.0011 ms
    for _ in range(6):
        fn()
    ############################
    n_repeat = 2000

    torch.cuda.synchronize()

    # construct a cuda graph with `n_repeat` unrolled function calls to minimize
    # host overhead
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        for _ in range(n_repeat):
            fn()

    # measure replay latency with CUDA events
    ret = []
    n_retries = 10
    for _ in range(n_retries):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        g.replay()
        end_event.record()
        torch.cuda.synchronize()
        measured_latency = start_event.elapsed_time(end_event) / n_repeat
        ret.append(measured_latency)

compiler_latency = sum(ret) / len(ret)
print("compiler_latency:", compiler_latency)

cc @msaroufim @jerryzh168 @ptrblck @eqy @mcarilli @ezyang @eellison @penguinwu @galv

Labels: module: cuda, module: cuda graphs, module: performance, needs reproduction, triaged
