Closed
Labels: triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
When I use the two Flash Attention versions, the training speed is almost the same. From my understanding, shouldn't v2 be faster than v1?
PyTorch versions:
v1: torch 2.0.1+cu117
v2: 2.2.0a0+git18e1a37 (installed from source today)
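For reference, a quick way to check that the flash SDPA backend is at least enabled in a given build (this only reports the backend toggle; it does not distinguish flash v1 from v2):

import torch

print(torch.__version__)
# True means the flash scaled_dot_product_attention backend is allowed in this build/session.
print(torch.backends.cuda.flash_sdp_enabled())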
This is my test code:
import os
import time

import torch
import torch.distributed as distributed
from tqdm import tqdm

# ViT_S_16 is a ViT-S/16 backbone defined elsewhere in my code.
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
distributed.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)

backbone = ViT_S_16().cuda()
backbone = torch.nn.parallel.DistributedDataParallel(
    module=backbone,
    bucket_cap_mb=32,
    find_unused_parameters=True,
    static_graph=True)

auto_scaler = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=200)
ce = torch.nn.CrossEntropyLoss().cuda()
opt = torch.optim.AdamW(backbone.parameters(), 0.001)

t = time.time()
for i in tqdm(range(1000)):
    # Synthetic batch: 512 images of 3x224x224 with dummy labels.
    img = torch.randn(512, 3, 224, 224, device=local_rank)
    local_labels = torch.zeros(size=(512,), dtype=torch.long, device=local_rank)
    with torch.cuda.amp.autocast(True):
        result = backbone(img)
        loss = ce(result, local_labels)
    # Scaled backward pass and optimizer step under the GradScaler.
    auto_scaler.scale(loss).backward()
    auto_scaler.unscale_(opt)
    auto_scaler.step(opt)
    auto_scaler.update()
    opt.zero_grad()
print((time.time() - t) / 1000)  # average seconds per iteration
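To verify which attention kernel the benchmark actually runs, here is a minimal sketch (not part of the test script above) that forces scaled_dot_product_attention onto the flash backend and profiles one call; the (batch, heads, seq_len, head_dim) shapes are assumptions roughly matching ViT-S/16. The CUDA kernel names in the profiler table show which flash implementation each build dispatches to.

import torch
import torch.nn.functional as F

# Assumed attention shapes for ViT-S/16: batch 512, 6 heads, 197 tokens, head_dim 64.
q = torch.randn(512, 6, 197, 64, device="cuda", dtype=torch.float16)
k = torch.randn(512, 6, 197, 64, device="cuda", dtype=torch.float16)
v = torch.randn(512, 6, 197, 64, device="cuda", dtype=torch.float16)

# Restrict SDPA to the flash backend only (PyTorch 2.0-2.2 context manager); this raises
# an error instead of silently falling back to the math or mem-efficient backends.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
        out = F.scaled_dot_product_attention(q, k, v)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))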