Update on "[FSDP2] Fixed 2D clip grad norm test"
This fixes #126484.

We change from a transformer to an MLP stack, since the transformer seems to introduce slight numeric differences when using TP. We include a sequence-parallel layer norm module in the MLP stack to exercise the `(S(0), R)` placement.

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
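For context, here is a minimal sketch (not the actual test code) of the kind of MLP stack with a sequence-parallel layer norm such a test might use. The module and helper names (`MLPBlock`, `parallelize_block`) are illustrative assumptions; only the parallel styles come from `torch.distributed.tensor.parallel`. Under `SequenceParallel` the norm's weight stays replicated on the TP mesh, so sharding it with FSDP on the data-parallel mesh yields the 2D placement `(Shard(0), Replicate())`, i.e. `(S(0), R)`.

```python
# Illustrative sketch only; MLPBlock and parallelize_block are hypothetical names,
# not the contents of test_fully_shard_clip_grad_norm_.py.
import torch
import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


class MLPBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # sequence-parallel layer norm under TP
        self.in_proj = nn.Linear(dim, 4 * dim)
        self.out_proj = nn.Linear(4 * dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_proj(torch.relu(self.in_proj(self.norm(x))))


def parallelize_block(block: MLPBlock, tp_mesh):
    # SequenceParallel keeps the LayerNorm weight replicated on the TP mesh;
    # applying FSDP's fully_shard on top then gives a (Shard(0), Replicate()) placement.
    return parallelize_module(
        block,
        tp_mesh,
        {
            "norm": SequenceParallel(),
            "in_proj": ColwiseParallel(),
            "out_proj": RowwiseParallel(),
        },
    )
```

Presumably the test then stacks several such blocks, applies `fully_shard` over the TP-parallelized model, and compares `torch.nn.utils.clip_grad_norm_` against a single-device reference.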
awgu committed May 21, 2024
2 parents 8a1cce1 + afd4a64 commit e09cfec
Showing 110 changed files with 2,746 additions and 1,038 deletions.
3 changes: 3 additions & 0 deletions .ci/docker/centos-rocm/Dockerfile
@@ -77,6 +77,9 @@ RUN rm install_rocm.sh
 COPY ./common/install_rocm_magma.sh install_rocm_magma.sh
 RUN bash ./install_rocm_magma.sh
 RUN rm install_rocm_magma.sh
+COPY ./common/install_amdsmi.sh install_amdsmi.sh
+RUN bash ./install_amdsmi.sh
+RUN rm install_amdsmi.sh
 ENV PATH /opt/rocm/bin:$PATH
 ENV PATH /opt/rocm/hcc/bin:$PATH
 ENV PATH /opt/rocm/hip/bin:$PATH
5 changes: 5 additions & 0 deletions .ci/docker/common/install_amdsmi.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+
+set -ex
+
+cd /opt/rocm/share/amd_smi && pip install .
6 changes: 4 additions & 2 deletions .ci/docker/common/install_rocm.sh
@@ -39,7 +39,8 @@ install_ubuntu() {
 rocm-libs \
 rccl \
 rocprofiler-dev \
-roctracer-dev
+roctracer-dev \
+amd-smi-lib

 if [[ $(ver $ROCM_VERSION) -ge $(ver 6.1) ]]; then
     DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev
@@ -106,7 +107,8 @@ install_centos() {
 rocm-libs \
 rccl \
 rocprofiler-dev \
-roctracer-dev
+roctracer-dev \
+amd-smi-lib

 # precompiled miopen kernels; search for all unversioned packages
 # if search fails it will abort this script; use true to avoid case where search fails
5 changes: 5 additions & 0 deletions .ci/docker/ubuntu-rocm/Dockerfile
@@ -78,6 +78,11 @@ ENV MAGMA_HOME /opt/rocm/magma
 ENV LANG C.UTF-8
 ENV LC_ALL C.UTF-8

+# Install amdsmi
+COPY ./common/install_amdsmi.sh install_amdsmi.sh
+RUN bash ./install_amdsmi.sh
+RUN rm install_amdsmi.sh
+
 # (optional) Install non-default CMake version
 ARG CMAKE_VERSION
 COPY ./common/install_cmake.sh install_cmake.sh
22 changes: 16 additions & 6 deletions .ci/pytorch/test.sh
@@ -326,7 +326,7 @@ test_inductor_distributed() {
 python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_frozen.py --verbose
 python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_compute_dtype --verbose
 python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_reduce_dtype --verbose
-python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_clip_grad_norm.py -k test_clip_grad_norm_2d --verbose
+python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py -k test_clip_grad_norm_2d --verbose
 python test/run_test.py -i distributed/fsdp/test_fsdp_tp_integration.py -k test_fsdp_tp_integration --verbose

 # this runs on both single-gpu and multi-gpu instance. It should be smart about skipping tests that aren't supported
@@ -351,10 +351,20 @@ test_inductor() {

 test_inductor_cpp_wrapper_abi_compatible() {
 export TORCHINDUCTOR_ABI_COMPATIBLE=1
+TEST_REPORTS_DIR=$(pwd)/test/test-reports
+mkdir -p "$TEST_REPORTS_DIR"

 echo "Testing Inductor cpp wrapper mode with TORCHINDUCTOR_ABI_COMPATIBLE=1"
 # cpu stack allocation causes segfault and needs more investigation
 TORCHINDUCTOR_STACK_ALLOCATION=0 python test/run_test.py --include inductor/test_cpu_cpp_wrapper
 python test/run_test.py --include inductor/test_cuda_cpp_wrapper
+
+TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/timm_models.py --device cuda --accuracy --amp \
+  --training --inductor --disable-cudagraphs --only vit_base_patch16_224 \
+  --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_training.csv"
+python benchmarks/dynamo/check_accuracy.py \
+  --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_training.csv" \
+  --expected "benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv"
 }
@@ -557,12 +567,12 @@ test_inductor_torchbench_smoketest_perf() {
 TEST_REPORTS_DIR=$(pwd)/test/test-reports
 mkdir -p "$TEST_REPORTS_DIR"

-# smoke test the cpp_wrapper mode
-TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/torchbench.py --device cuda --accuracy --bfloat16 \
-  --inference --inductor --only hf_T5 --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_smoketest.csv"
+# Test some models in the cpp wrapper mode
+TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/torchbench.py --device cuda --accuracy \
+  --bfloat16 --inference --inductor --only hf_T5 --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv"
 python benchmarks/dynamo/check_accuracy.py \
-  --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_smoketest.csv" \
-  --expected "benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv"
+  --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" \
+  --expected "benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv"

 python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --float16 --training \
   --batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" --only hf_Bert \
4 changes: 0 additions & 4 deletions .gitmodules
@@ -74,10 +74,6 @@
 ignore = dirty
 path = third_party/gemmlowp/gemmlowp
 url = https://github.com/google/gemmlowp.git
-[submodule "third_party/QNNPACK"]
-ignore = dirty
-path = third_party/QNNPACK
-url = https://github.com/pytorch/QNNPACK
 [submodule "third_party/neon2sse"]
 ignore = dirty
 path = third_party/neon2sse
2 changes: 1 addition & 1 deletion BUILD.bazel
@@ -772,7 +772,7 @@ cc_library(
     [
         "torch/*.h",
         "torch/csrc/**/*.h",
-        "torch/csrc/distributed/c10d/*.hpp",
+        "torch/csrc/distributed/c10d/**/*.hpp",
         "torch/lib/libshm/*.h",
     ],
     exclude = [
14 changes: 9 additions & 5 deletions aten/src/ATen/autocast_mode.cpp
@@ -35,7 +35,11 @@ namespace {
 // directly against incoming TensorImpl*s.
 using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
 using val_type = std::tuple<weakref_type, Tensor>;
-ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
+
+static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
+  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
+  return cached_casts;
+}
 std::mutex cached_casts_mutex;


@@ -82,7 +86,7 @@ thread_local bool cache_enabled = true;

 void clear_cache() {
   const std::lock_guard<std::mutex> lock(cached_casts_mutex);
-  cached_casts.clear();
+  get_cached_casts().clear();
 }

 int increment_nesting() {
@@ -124,12 +128,12 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_

 if (can_try_cache) {
   const std::lock_guard<std::mutex> lock(cached_casts_mutex);
-  auto it = cached_casts.find(arg.unsafeGetTensorImpl());
-  if (it != cached_casts.end()) {
+  auto it = get_cached_casts().find(arg.unsafeGetTensorImpl());
+  if (it != get_cached_casts().end()) {
     return std::get<1>(it->second);
   } else {
     auto casted_arg = arg.to(to_type);
-    cached_casts.emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg});
+    get_cached_casts().emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg});
     return casted_arg;
   }
 } else {
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CUDAGraph.cpp
@@ -268,7 +268,7 @@ void CUDAGraph::debug_dump(const std::string& debug_path) {
 TORCH_WARN("DEBUG: calling debug_dump()");
 if (has_graph_) {
   TORCH_WARN("DEBUG: calling cudaGraphDebugDotPrint() with ", debug_path);
-  C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), 1<<10)); // most verbose output
+  C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), cudaGraphDebugDotFlagsVerbose)); // most verbose output
   AT_CUDA_CHECK(cudaGraphDestroy(graph_));
 }
 } else {
6 changes: 4 additions & 2 deletions aten/src/ATen/native/mps/operations/TensorCompare.mm
@@ -276,8 +276,6 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements,
                                        bool invert,
                                        const Tensor& out,
                                        string op_name) {
-  TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
-              "isin_Tensor_Tensor_out supported on MPS from MacOs_14_0 onwards");
   if (elements.numel() == 0) {
     return;
   }
@@ -295,6 +293,10 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements,

 TORCH_CHECK(elements.is_mps() && test_elements.is_mps());
 TORCH_CHECK(elements.dtype() == test_elements.dtype());
+  TORCH_CHECK(
+      !(!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) && !supportedFloatingType(elements.scalar_type())),
+      "isin_Tensor_Tensor_out only works on floating types on MPS for pre MacOS_14_0. Received dtype: ",
+      elements.scalar_type());

 @autoreleasepool {
   string key =
29 changes: 29 additions & 0 deletions benchmarks/dynamo/common.py
@@ -3485,6 +3485,18 @@ def get_example_inputs(self):
         action="store_true",
         help="Measure speedup with TorchInductor",
     )
+    group.add_argument(
+        "--quantization",
+        choices=[
+            "int8dynamic",
+            "int8weightonly",
+            "int4weightonly",
+            "autoquant",
+            "noquant",
+        ],
+        default=None,
+        help="Measure speedup of torchao quantization with TorchInductor baseline",
+    )
     group.add_argument(
         "--export",
         action="store_true",
@@ -3679,6 +3691,9 @@ def run(runner, args, original_dir=None):
     if args.inductor:
         assert args.backend is None
         args.backend = "inductor"
+    if args.quantization:
+        assert args.backend is None
+        args.backend = "torchao"
     if args.dynamic_batch_only:
         args.dynamic_shapes = True
         torch._dynamo.config.assume_static_by_default = True
@@ -3957,6 +3972,20 @@ def run(runner, args, original_dir=None):

         # AOTInductor doesn't support control flow yet
         runner.skip_models.update(runner.skip_models_due_to_control_flow)
+    elif args.backend == "torchao":
+        assert "cuda" in args.devices, "Quantization requires CUDA device."
+        assert args.bfloat16, "Quantization requires dtype bfloat16."
+        from .torchao import setup_baseline, torchao_optimize_ctx
+
+        setup_baseline()
+        baseline_ctx = functools.partial(
+            torch.compile,
+            backend="inductor",
+            fullgraph=args.nopython,
+            mode=args.inductor_compile_mode,
+        )
+        runner.model_iter_fn = baseline_ctx(runner.model_iter_fn)
+        optimize_ctx = torchao_optimize_ctx(args.quantization)
     else:
         optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
     experiment = speedup_experiment
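The new `--quantization` flag would presumably be driven from the existing benchmark CLIs. A hypothetical invocation is sketched below; the flag and its choices come from the diff above, while the model choice and output path are illustrative, and per the added assertions the run must target CUDA with bfloat16 and must not also pass `--inductor`.

```bash
# Hypothetical invocation; model (hf_T5) and output file are illustrative only.
python benchmarks/dynamo/torchbench.py --device cuda --bfloat16 --inference \
    --performance --quantization int8weightonly --only hf_T5 \
    --output speedup_torchao_int8weightonly.csv
```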
[Diffs for the remaining changed files are not shown here.]
