
Commit

Update
[ghstack-poisoned]
jansel committed Jun 19, 2024
2 parents 912ca31 + 4dbb736 commit b6c73b8
Showing 56 changed files with 1,601 additions and 947 deletions.
47 changes: 27 additions & 20 deletions .github/scripts/get_workflow_type.py
@@ -1,6 +1,6 @@
import json
from argparse import ArgumentParser
from typing import Any
from typing import Any, Tuple

from github import Auth, Github
from github.Issue import Issue
@@ -9,6 +9,8 @@
WORKFLOW_LABEL_META = "" # use meta runners
WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation
LABEL_TYPE_KEY = "label_type"
MESSAGE_KEY = "message"
MESSAGE = "" # Debug message to return to the caller


def parse_args() -> Any:
@@ -48,45 +50,50 @@ def is_exception_branch(branch: str) -> bool:
return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"}


def get_workflow_type(issue: Issue, username: str) -> str:
def get_workflow_type(issue: Issue, username: str) -> Tuple[str, str]:
try:
user_list = issue.get_comments()[0].body.split()

if user_list[0] == "!":
print("LF Workflows are disabled for everyone. Using meta runners.")
return WORKFLOW_LABEL_META
MESSAGE = "LF Workflows are disabled for everyone. Using meta runners."
return WORKFLOW_LABEL_META, MESSAGE
elif user_list[0] == "*":
print("LF Workflows are enabled for everyone. Using LF runners.")
return WORKFLOW_LABEL_LF
MESSAGE = "LF Workflows are enabled for everyone. Using LF runners."
return WORKFLOW_LABEL_LF, MESSAGE
elif username in user_list:
print(f"LF Workflows are enabled for {username}. Using LF runners.")
return WORKFLOW_LABEL_LF
MESSAGE = f"LF Workflows are enabled for {username}. Using LF runners."
return WORKFLOW_LABEL_LF, MESSAGE
else:
print(f"LF Workflows are disabled for {username}. Using meta runners.")
return WORKFLOW_LABEL_META
MESSAGE = f"LF Workflows are disabled for {username}. Using meta runners."
return WORKFLOW_LABEL_META, MESSAGE
except Exception as e:
print(
f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}"
)
return WORKFLOW_LABEL_META
MESSAGE = f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}"
return WORKFLOW_LABEL_META, MESSAGE


def main() -> None:
args = parse_args()

if is_exception_branch(args.github_branch):
print(f"Exception branch: '{args.github_branch}', using meta runners")
output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META}
output = {
LABEL_TYPE_KEY: WORKFLOW_LABEL_META,
MESSAGE_KEY: f"Exception branch: '{args.github_branch}', using meta runners",
}
else:
try:
gh = get_gh_client(args.github_token)
# The default issue we use - https://github.com/pytorch/test-infra/issues/5132
issue = get_issue(gh, args.github_repo, args.github_issue)

output = {LABEL_TYPE_KEY: get_workflow_type(issue, args.github_user)}
label_type, message = get_workflow_type(issue, args.github_user)
output = {
LABEL_TYPE_KEY: label_type,
MESSAGE_KEY: message,
}
except Exception as e:
print(f"Failed to get issue. Falling back to meta runners. Exception: {e}")
output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META}
output = {
LABEL_TYPE_KEY: WORKFLOW_LABEL_META,
MESSAGE_KEY: f"Failed to get issue. Falling back to meta runners. Exception: {e}",
}

json_output = json.dumps(output)
print(json_output)
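
For reference, here is a minimal sketch of the JSON this updated script now emits on stdout, for the case of a user who appears in the opt-in issue's first comment (the username is illustrative; "lf." is WORKFLOW_LABEL_LF, and the message is the debug string that previously went to print()):

import json

# Hypothetical output for a user listed in the rollover issue.
output = {
    "label_type": "lf.",
    "message": "LF Workflows are enabled for some-user. Using LF runners.",
}
print(json.dumps(output))
# {"label_type": "lf.", "message": "LF Workflows are enabled for some-user. Using LF runners."}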
1 change: 1 addition & 0 deletions .lintrunner.toml
@@ -68,6 +68,7 @@ include_patterns = [
'aten/src/ATen/native/cudnn/*.cpp',
'c10/**/*.h',
'c10/**/*.cpp',
'distributed/c10d/*SymmetricMemory.*',
'torch/csrc/**/*.h',
'torch/csrc/**/*.hpp',
'torch/csrc/**/*.cpp',
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -744,6 +744,7 @@ cc_library(
"torch/csrc/cuda/python_nccl.cpp",
"torch/csrc/cuda/nccl.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
"torch/csrc/distributed/c10d/Utils.cu",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
],
9 changes: 2 additions & 7 deletions aten/src/ATen/FunctionalInverses.cpp
@@ -303,7 +303,7 @@ Tensor FunctionalInverses::_nested_view_from_buffer_inverse(const Tensor& base,
return Tensor();
}

Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional<Tensor>& lengths, int64_t ragged_idx, const c10::optional<Tensor>& min_seqlen, const c10::optional<Tensor>& max_seqlen) {
Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional<Tensor>& lengths, int64_t ragged_idx) {
auto values = at::_nested_get_values(mutated_view);
if (inverse_return_mode != InverseReturnMode::NeverView) {
return values;
@@ -317,12 +317,7 @@ Tensor FunctionalInverses::_nested_get_values_inverse(const Tensor& base, const
auto lengths = at::_nested_get_lengths(base);
auto ragged_idx = at::_nested_get_ragged_idx(base);
auto dummy = at::_nested_get_jagged_dummy(base);
auto min_seqlen = at::_nested_get_min_seqlen(base);
auto max_seqlen = at::_nested_get_max_seqlen(base);
auto nt = at::_nested_view_from_jagged(
mutated_view, offsets, dummy, lengths, ragged_idx,
(min_seqlen.defined() ? c10::optional<Tensor>(min_seqlen) : c10::nullopt),
(max_seqlen.defined() ? c10::optional<Tensor>(max_seqlen) : c10::nullopt));
auto nt = at::_nested_view_from_jagged(mutated_view, offsets, dummy, lengths, ragged_idx);

if (inverse_return_mode != InverseReturnMode::NeverView) {
return nt;
8 changes: 7 additions & 1 deletion aten/src/ATen/cuda/Atomic.cuh
@@ -334,7 +334,13 @@ static inline __device__ void gpuAtomicAddNoReturn(double *address, double val)

/* Special case fp32 atomic. */
#if defined(USE_ROCM)
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { atomicAddNoRet(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) {
#if defined(__gfx908__)
atomicAddNoRet(address, val);
#else
(void)unsafeAtomicAdd(address, val);
#endif
}
#else
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }
#endif
14 changes: 2 additions & 12 deletions aten/src/ATen/native/native_functions.yaml
@@ -6185,12 +6185,12 @@
CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy
autogen: _nested_view_from_buffer_copy.out

- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a)
- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a)
variants: function
device_check: NoCheck
dispatch: {}

- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor
- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor
variants: function
device_check: NoCheck
tags: view_copy
@@ -6227,16 +6227,6 @@
device_check: NoCheck
dispatch: {}

- func: _nested_get_min_seqlen(Tensor self) -> Tensor
variants: function
device_check: NoCheck
dispatch: {}

- func: _nested_get_max_seqlen(Tensor self) -> Tensor
variants: function
device_check: NoCheck
dispatch: {}

- func: _nested_get_jagged_dummy(Tensor any) -> Tensor
category_override: dummy
dispatch: {}
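
These private ops back PyTorch's public jagged-layout constructor. As a minimal usage sketch, assuming the torch.nested API of this release (shapes are illustrative):

import torch

values = torch.randn(10, 4)            # packed values of three variable-length rows
offsets = torch.tensor([0, 3, 7, 10])  # row i spans values[offsets[i]:offsets[i+1]]
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
print(nt.is_nested, nt.values().shape)  # True torch.Size([10, 4])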
4 changes: 1 addition & 3 deletions benchmarks/transformer/score_mod.py
@@ -264,7 +264,7 @@ def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
batch_sizes = [2, 8, 16]
num_heads = [16]
q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
head_dims = [64, 128, 256]
head_dims = [64, 128]
dtypes = [
torch.bfloat16,
]
@@ -302,8 +302,6 @@ def main(dynamic: bool, calculate_bwd: bool):
results.append(
Experiment(config, run_single_experiment(config, dynamic=dynamic))
)
for config in tqdm(generate_experiment_configs(calculate_bwd)):
results.append(Experiment(config, run_single_experiment(config)))

print_results(results)

2 changes: 2 additions & 0 deletions build_variables.bzl
@@ -501,6 +501,7 @@ libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/ProcessGroupMPI.cpp",
"torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp",
"torch/csrc/distributed/c10d/Store.cpp",
"torch/csrc/distributed/c10d/SymmetricMemory.cpp",
"torch/csrc/distributed/c10d/TCPStore.cpp",
"torch/csrc/distributed/c10d/TCPStoreBackend.cpp",
"torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp",
@@ -684,6 +685,7 @@ libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/UCCUtils.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
"torch/csrc/distributed/c10d/Utils.cu",
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
19 changes: 11 additions & 8 deletions c10/cuda/driver_api.h
@@ -18,14 +18,17 @@
} \
} while (0)

#define C10_LIBCUDA_DRIVER_API(_) \
_(cuMemAddressReserve) \
_(cuMemRelease) \
_(cuMemMap) \
_(cuMemAddressFree) \
_(cuMemSetAccess) \
_(cuMemUnmap) \
_(cuMemCreate) \
#define C10_LIBCUDA_DRIVER_API(_) \
_(cuMemAddressReserve) \
_(cuMemRelease) \
_(cuMemMap) \
_(cuMemAddressFree) \
_(cuMemSetAccess) \
_(cuMemUnmap) \
_(cuMemCreate) \
_(cuMemGetAllocationGranularity) \
_(cuMemExportToShareableHandle) \
_(cuMemImportFromShareableHandle) \
_(cuGetErrorString)

#define C10_NVML_DRIVER_API(_) \
1 change: 1 addition & 0 deletions caffe2/CMakeLists.txt
@@ -560,6 +560,7 @@ if(USE_CUDA)
append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS)
set_source_files_properties(
${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp
${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
)
endif()
82 changes: 71 additions & 11 deletions test/distributed/_composable/fsdp/test_fully_shard_overlap.py
@@ -1,12 +1,14 @@
# Owner(s): ["oncall: distributed"]

import functools
from typing import Callable

import torch
import torch.distributed as dist
import torch.nn as nn

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor.experimental import implicit_replication
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
FSDPTest,
@@ -23,15 +25,6 @@ def world_size(self) -> int:

@skip_if_lt_x_gpu(2)
def test_fully_shard_training_overlap(self):
class LinearWithSleep(nn.Module):
def __init__(self, dim: int, sleep_ms: int):
super().__init__()
self.weight = nn.Parameter(torch.randn((dim, dim)))
self.sleep_ms = sleep_ms

def forward(self, x: torch.Tensor) -> torch.Tensor:
return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms))

torch.manual_seed(42)

# Use non-trivial comm. time but still shorter than compute time
@@ -44,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
fully_shard(model, reshard_after_forward=True)

orig_all_gather_into_tensor = dist.all_gather_into_tensor
orig_reduce_scatter = dist.reduce_scatter_tensor
orig_reduce_scatter_tensor = dist.reduce_scatter_tensor
comm_stream = torch.cuda.Stream()

def delay_collective():
@@ -61,7 +54,7 @@ def delayed_all_gather(*args, **kwargs):

def delayed_reduce_scatter(*args, **kwargs):
delay_collective()
return orig_reduce_scatter(*args, **kwargs)
return orig_reduce_scatter_tensor(*args, **kwargs)

inp = torch.randn((2, dim), device="cuda")
loss = model(inp).sum() # warmup CUDA and allocator
@@ -92,6 +85,63 @@ def fwd_bwd():
)
self.assertLessEqual(fwd_bwd_time, expected_fwd_time + expected_bwd_time)

@skip_if_lt_x_gpu(2)
def test_fully_shard_post_optim_event_overlap(self):
torch.manual_seed(42)

# Use non-trivial comm. time but still shorter than compute time
dim, compute_sleep_ms, comm_sleep_ms = (4, 25, 10)
# Define the model to have a high-compute linear followed by a
# low-compute linear, where only the low-compute linear uses FSDP
model = nn.Sequential(
LinearWithSleep(dim, compute_sleep_ms), nn.Linear(dim, dim)
).cuda()
fully_shard(model[1], reshard_after_forward=False)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

orig_all_gather_into_tensor = dist.all_gather_into_tensor

def delayed_all_gather(*args, **kwargs):
torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms()))
return orig_all_gather_into_tensor(*args, **kwargs)

inp = torch.randn((2, dim), device="cuda")

def run_train_steps(num_iters: int, use_post_optim_event: bool):
for _ in range(num_iters):
optim.zero_grad()
with patch_all_gather(delayed_all_gather):
loss = model(inp).sum()
loss.backward()
with implicit_replication():
optim.step()
if use_post_optim_event:
post_optim_event = torch.cuda.current_stream().record_event()
model[1].set_post_optim_event(post_optim_event)

run_train_steps(1, False) # warmup CUDA and allocator
num_iters = 5
baseline_time = self._time_fn(
functools.partial(run_train_steps, num_iters, False)
)
test_time = self._time_fn(functools.partial(run_train_steps, num_iters, True))

buffer_ms = 4 # CPU delays and copies
# Baseline: FSDP all-gather is exposed since the FSDP module waits for
# the current stream and hence the high-compute linear
self.assertLessEqual(
baseline_time,
num_iters * (3 * compute_sleep_ms + comm_sleep_ms + buffer_ms),
)
# Test: FSDP all-gather is overlapped with the high-compute linear
# since the FSDP module only waits for the post-optim event (except on
# the 1st iteration when no event has been recorded)
expected_test_time = (
num_iters * (3 * compute_sleep_ms + buffer_ms) + comm_sleep_ms
)
self.assertLessEqual(test_time, expected_test_time)
self.assertGreater(baseline_time, expected_test_time)

def _time_fn(self, fn: Callable):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
@@ -123,5 +173,15 @@ def backward(ctx, grad_output: torch.Tensor):
return grad_input, grad_weight, None


class LinearWithSleep(nn.Module):
def __init__(self, dim: int, sleep_ms: int):
super().__init__()
self.weight = nn.Parameter(torch.randn((dim, dim)))
self.sleep_ms = sleep_ms

def forward(self, x: torch.Tensor) -> torch.Tensor:
return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms))


if __name__ == "__main__":
run_tests()
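
As a back-of-envelope check of the bounds asserted in test_fully_shard_post_optim_event_overlap, using its constants (the factor of 3 presumably counts the forward matmul plus the two backward matmuls of LinearWithSleep):

compute_sleep_ms, comm_sleep_ms, buffer_ms, num_iters = 25, 10, 4, 5
baseline_bound = num_iters * (3 * compute_sleep_ms + comm_sleep_ms + buffer_ms)  # 445 ms
overlap_bound = num_iters * (3 * compute_sleep_ms + buffer_ms) + comm_sleep_ms   # 405 ms
assert overlap_bound < baseline_bound  # overlap hides all but the first all-gather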