
About FlashAttention v1 and v2 speed #114209

@daixiangzi

Description


When I run the same training script under the two FlashAttention versions, the per-step speed is almost identical. From my understanding, shouldn't v2 be faster than v1?
PyTorch versions:
v1: 2.0.1+cu117
v2: 2.2.0a0+git18e1a37, built from source today
This is my test code:

    import os
    import time

    import torch
    from torch import distributed
    from tqdm import tqdm

    from vit import ViT_S_16  # placeholder import: ViT_S_16 is my own model definition

    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

    backbone = ViT_S_16().cuda()
    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone,
        bucket_cap_mb=32,
        find_unused_parameters=True,
        static_graph=True)

    auto_scaler = torch.cuda.amp.GradScaler(growth_interval=200)
    ce = torch.nn.CrossEntropyLoss().cuda()
    opt = torch.optim.AdamW(backbone.parameters(), 0.001)

    t = time.time()
    for i in tqdm(range(1000)):
        # Synthetic batch: 512 images at 224x224 with dummy labels.
        img = torch.randn(512, 3, 224, 224, device=local_rank)
        local_labels = torch.zeros(size=(512,), dtype=torch.long, device=local_rank)
        with torch.cuda.amp.autocast(True):
            result = backbone(img)
            loss = ce(result, local_labels)
        auto_scaler.scale(loss).backward()
        auto_scaler.unscale_(opt)
        auto_scaler.step(opt)
        auto_scaler.update()
        opt.zero_grad()
    torch.cuda.synchronize()  # CUDA launches are asynchronous; wait for the GPU before reading the clock
    print((time.time() - t) / 1000)
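
One thing to keep in mind: a full DDP training step of a ViT-S/16 spends most of its time outside attention (MLP blocks, optimizer step, gradient communication, data generation), so a faster attention kernel may barely move the end-to-end step time. At 224x224 with patch size 16 the sequence length is only 197, where the v1/v2 gap tends to be small to begin with. Below is a minimal sketch that times the attention call in isolation and restricts SDPA to the FlashAttention backend; it assumes the model routes through F.scaled_dot_product_attention, and the 512x6x197x64 shape is my guess at the ViT-S/16 attention shape, so adjust to your actual config:

    import time

    import torch
    import torch.nn.functional as F

    # Assumed ViT-S/16 attention shape at 224x224: 6 heads, 197 tokens, head dim 64.
    q = torch.randn(512, 6, 197, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    def bench(fn, iters=100, warmup=10):
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()  # launches are async; drain the GPU before starting the clock
        t = time.time()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()  # and again before stopping it
        return (time.time() - t) / iters

    # Allow only the FlashAttention backend; SDPA then raises a RuntimeError
    # instead of silently falling back if flash cannot handle these inputs.
    with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                        enable_math=False,
                                        enable_mem_efficient=False):
        print(bench(lambda: F.scaled_dot_product_attention(q, k, v)))

Running this under both builds should show whether the kernels themselves differ; if the isolated numbers diverge but the training step time does not, the attention kernel simply isn't the bottleneck in this workload.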
