
Commit

Update on "[FSDP2] allow meta tensors during loading state dict and cpu offloading"


Unit test: ``pytest test/distributed/_composable/fsdp/test_fully_shard_state_dict.py``

With meta init and CPU offloading, we have meta tensors after `model.load_state_dict(assign=True, strict=False)`. This PR avoids calling `.cpu()` on meta tensors, which would otherwise raise a runtime error.
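
For illustration, a minimal sketch of the failure mode and the guard, using only public PyTorch APIs (`maybe_offload_to_cpu` is a hypothetical helper, not the PR's actual code): meta tensors have no storage, so copying them to CPU fails, and the offload path has to skip them.

```python
import torch

def maybe_offload_to_cpu(t: torch.Tensor) -> torch.Tensor:
    # Meta tensors carry no data, so .cpu() raises at runtime;
    # leave them alone and let the caller materialize them later.
    if t.is_meta:
        return t
    return t.cpu()

meta_param = torch.empty(4, 4, device="meta")
real_param = torch.randn(4, 4)

print(maybe_offload_to_cpu(meta_param).device)  # meta (copy skipped)
print(maybe_offload_to_cpu(real_param).device)  # cpu
```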

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
weifengpy committed May 15, 2024
2 parents 12d6770 + 80d9f93 commit 5eb2e7d
Showing 68 changed files with 3,163 additions and 565 deletions.
47 changes: 45 additions & 2 deletions .ci/docker/build.sh
@@ -84,13 +84,27 @@ fi
# CMake 3.18 is needed to support CUDA17 language variant
CMAKE_VERSION=3.18.5

_UCX_COMMIT=00bcc6bb18fc282eb160623b4c0d300147f579af
_UCC_COMMIT=7cb07a76ccedad7e56ceb136b865eb9319c258ea
_UCX_COMMIT=7bb2722ff2187a0cad557ae4a6afa090569f83fb
_UCC_COMMIT=20eae37090a4ce1b32bcce6144ccad0b49943e0b

# It's annoying to rename jobs every time you want to rewrite a
# configuration, so we hardcode everything here rather than do it
# from scratch
case "$image" in
pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9)
CUDA_VERSION=12.4.0
CUDNN_VERSION=8
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=9
PROTOBUF=yes
DB=yes
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
CONDA_CMAKE=yes
TRITON=yes
;;
pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9)
CUDA_VERSION=12.1.1
CUDNN_VERSION=8
@@ -105,6 +119,21 @@ case "$image" in
CONDA_CMAKE=yes
TRITON=yes
;;
pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks)
CUDA_VERSION=12.4.0
CUDNN_VERSION=8
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=9
PROTOBUF=yes
DB=yes
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
CONDA_CMAKE=yes
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks)
CUDA_VERSION=12.1.1
CUDNN_VERSION=8
@@ -134,6 +163,20 @@ case "$image" in
CONDA_CMAKE=yes
TRITON=yes
;;
pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9)
CUDA_VERSION=12.4.0
CUDNN_VERSION=8
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=9
PROTOBUF=yes
DB=yes
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
CONDA_CMAKE=yes
TRITON=yes
;;
pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9)
CUDA_VERSION=12.1.1
CUDNN_VERSION=8
5 changes: 4 additions & 1 deletion .ci/docker/common/install_cudnn.sh
@@ -4,7 +4,10 @@ if [[ ${CUDNN_VERSION} == 8 ]]; then
# cuDNN license: https://developer.nvidia.com/cudnn/license_agreement
mkdir tmp_cudnn
pushd tmp_cudnn
if [[ ${CUDA_VERSION:0:4} == "12.1" ]]; then
if [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then
CUDNN_NAME="cudnn-linux-x86_64-8.9.7.29_cuda12-archive"
curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz
elif [[ ${CUDA_VERSION:0:4} == "12.1" ]]; then
CUDNN_NAME="cudnn-linux-x86_64-8.9.2.26_cuda12-archive"
curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz
elif [[ ${CUDA_VERSION:0:4} == "11.8" ]]; then
11 changes: 8 additions & 3 deletions .ci/docker/common/install_cusparselt.sh
@@ -5,9 +5,14 @@ set -ex
# cuSPARSELt license: https://docs.nvidia.com/cuda/cusparselt/license.html
mkdir tmp_cusparselt && cd tmp_cusparselt

if [[ ${CUDA_VERSION:0:4} == "12.1" ]]; then
CUSPARSELT_NAME="libcusparse_lt-linux-x86_64-0.5.2.1-archive"
curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-x86_64/${CUSPARSELT_NAME}.tar.xz
if [[ ${CUDA_VERSION:0:4} =~ ^12\.[1-4]$ ]]; then
arch_path='sbsa'
export TARGETARCH=${TARGETARCH:-$(uname -m)}
if [ ${TARGETARCH} = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then
arch_path='x86_64'
fi
CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.5.2.1-archive"
curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz
elif [[ ${CUDA_VERSION:0:4} == "11.8" ]]; then
CUSPARSELT_NAME="libcusparse_lt-linux-x86_64-0.4.0.7-archive"
curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-x86_64/${CUSPARSELT_NAME}.tar.xz
1 change: 1 addition & 0 deletions .ci/docker/ubuntu-cuda/Dockerfile
@@ -152,6 +152,7 @@ RUN rm install_cusparselt.sh
RUN if [ -h /usr/local/cuda-11.6/cuda-11.6 ]; then rm /usr/local/cuda-11.6/cuda-11.6; fi
RUN if [ -h /usr/local/cuda-11.7/cuda-11.7 ]; then rm /usr/local/cuda-11.7/cuda-11.7; fi
RUN if [ -h /usr/local/cuda-12.1/cuda-12.1 ]; then rm /usr/local/cuda-12.1/cuda-12.1; fi
RUN if [ -h /usr/local/cuda-12.1/cuda-12.4 ]; then rm /usr/local/cuda-12.1/cuda-12.4; fi

USER jenkins
CMD ["bash"]
9 changes: 9 additions & 0 deletions .ci/pytorch/test.sh
@@ -588,6 +588,15 @@ test_inductor_torchbench_smoketest_perf() {
"$TEST_REPORTS_DIR/inductor_training_smoketest_$test.csv" \
--expected benchmarks/dynamo/expected_ci_perf_inductor_torchbench.csv
done

# Perform some "warm-start" runs for a few huggingface models.
for test in AlbertForQuestionAnswering AllenaiLongformerBase DistilBertForMaskedLM DistillGPT2 GoogleFnet YituTechConvBert; do
python benchmarks/dynamo/huggingface.py --accuracy --training --amp --inductor --device cuda --warm-start-latency \
--only $test --output "$TEST_REPORTS_DIR/inductor_warm_start_smoketest_$test.csv"
python benchmarks/dynamo/check_accuracy.py \
--actual "$TEST_REPORTS_DIR/inductor_warm_start_smoketest_$test.csv" \
--expected "benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv"
done
}

test_inductor_torchbench_cpu_smoketest_perf(){
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/audio.txt
@@ -1 +1 @@
ea437b31ce316ea3d66fe73768c0dcb94edb79ad
1980f8af5bcd0bb2ce51965cf79d8d4c25dad8a0
2 changes: 2 additions & 0 deletions .github/workflows/docker-builds.yml
@@ -38,6 +38,8 @@ jobs:
matrix:
runner: [linux.12xlarge]
docker-image-name: [
pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9,
pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks,
pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9,
pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks,
pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9,
6 changes: 0 additions & 6 deletions .lintrunner.toml
@@ -1053,12 +1053,6 @@ exclude_patterns = [
'test/quantization/fx/test_quantize_fx.py',
'test/quantization/fx/test_subgraph_rewriter.py',
'test/test_datapipe.py',
'test/test_fake_tensor.py',
'test/test_flop_counter.py',
'test/test_function_schema.py',
'test/test_functional_autograd_benchmark.py',
'test/test_functional_optim.py',
'test/test_functionalization_of_rng_ops.py',
'test/test_futures.py',
'test/test_fx.py',
'test/test_fx_experimental.py',
115 changes: 112 additions & 3 deletions aten/src/ATen/native/mps/operations/Quantized.mm
@@ -12,6 +12,8 @@
#include <ATen/native/mps/OperationUtils.h>
#include <fmt/format.h>

// #define _CAPTURE_KERNEL 1

namespace at::native {

using namespace mps;
@@ -82,6 +84,85 @@ kernel void int4pack_mm(
INSTANTIATE_INT4MM(bfloat, 128);
INSTANTIATE_INT4MM(bfloat, 256);
#endif
template<typename T>
struct Vec4Type {};
template<>
struct Vec4Type<float> {
using type = float4;
};
template<>
struct Vec4Type<half> {
using type = half4;
};
#if __METAL_VERSION__ >= 310
template<>
struct Vec4Type<bfloat> {
using type = bfloat4;
};
#endif
template <typename T, unsigned blockSize=8>
kernel void
int8pack_mm(constant T *A [[buffer(0)]], constant char *B [[buffer(1)]],
constant T *scales [[buffer(2)]],
device T *outputData [[buffer(3)]],
constant int3 &sizes [[buffer(4)]],
uint2 group_index [[threadgroup_position_in_grid]],
uint2 threadgroup_index [[thread_position_in_threadgroup]]) {
using vecT = typename Vec4Type<T>::type;
const uint lda = sizes.y;
const uint ldc = sizes.z;
int out_idx = (group_index.x * blockSize + threadgroup_index.x) * 4;
int n = out_idx % sizes.z;
int m = out_idx / sizes.z;
// Offset pointers
A += m * lda;
B += n * lda;
outputData += m *ldc;
float4 rc = 0;
for (unsigned k = threadgroup_index.y * 4; k < sizes.y; k += 4 * blockSize) {
threadgroup_barrier(mem_flags::mem_none);
auto a_val = float4(*reinterpret_cast<constant vecT *>(A + k));
float4x4 b_val;
for (int i = 0; i < 4; ++i) {
b_val[i] = float4(*reinterpret_cast<constant char4 *>(B + i * lda + k));
}
rc += transpose(b_val) * a_val;
}
// Accumulate results across the SIMD group (8 threads using vec4)
threadgroup float4 tgp_memory[blockSize][blockSize];
tgp_memory[threadgroup_index.x][threadgroup_index.y] = rc;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (threadgroup_index.y == 0) {
for (int i = 1; i < blockSize; i++) {
rc += tgp_memory[threadgroup_index.x][i];
}
*reinterpret_cast<device vecT *>(outputData + n) =
vecT(rc * float4(*reinterpret_cast<constant vecT *>(scales + n)));
}
}
#define INSTANTIATE_INT8MM(DTYPE) \
template [[host_name("int8pack_mm_" #DTYPE)]] kernel void \
int8pack_mm<DTYPE>( \
constant DTYPE * A [[buffer(0)]], constant char *B [[buffer(1)]], \
constant DTYPE *scales [[buffer(2)]], \
device DTYPE *outputData [[buffer(3)]], \
constant int3 &sizes [[buffer(4)]], \
uint2 group_index [[threadgroup_position_in_grid]], \
uint2 threadgroup_index [[thread_position_in_threadgroup]]);
INSTANTIATE_INT8MM(half);
INSTANTIATE_INT8MM(float);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT8MM(bfloat);
#endif
)METAL_QUANTIZED");

Tensor _weight_int4pack_mm_mps(const Tensor& A, const Tensor& B, int64_t qGroupSize, const Tensor& qScaleAndZeros) {
@@ -114,8 +195,7 @@ Tensor _weight_int4pack_mm_mps(const Tensor& A, const Tensor& B, int64_t qGroupS

auto C = at::empty({M, N}, A.options());
MPSStream* mpsStream = getCurrentMPSStream();
std::array<uint32_t, 3> sizes = {static_cast<uint32_t>(M), static_cast<uint32_t>(K), static_cast<uint32_t>(N)};
static bool firstCapture = false;
std::array<uint32_t, 4> sizes = {static_cast<uint32_t>(M), static_cast<uint32_t>(K), static_cast<uint32_t>(N), 0};
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
#if _CAPTURE_KERNEL
@@ -163,7 +243,35 @@ Tensor _weight_int8pack_mm_mps(const Tensor& A, const Tensor& B, const Tensor& s
TORCH_CHECK(scales.dim() == 1 && scales.size(0) == N, __func__, " : expect scales to be 1d tensor with size ", N);

auto C = at::empty({M, N}, A.options());

TORCH_CHECK(N % 32 == 0 && K % 32 == 0);
#if 1
MPSStream* mpsStream = getCurrentMPSStream();
std::array<uint32_t, 4> sizes = {static_cast<uint32_t>(M), static_cast<uint32_t>(K), static_cast<uint32_t>(N), 0};
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
#if _CAPTURE_KERNEL
if (getMPSProfiler().isCaptureEnabled()) {
getMPSProfiler().startCapture(fmt::format("int8pack_mm_{}x{}x{}", M, N, K), mpsStream);
}
#endif
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
const std::string kernel = fmt::format("int8pack_mm_{}", scalarToMetalTypeString(A));
id<MTLComputePipelineState> quantizedPSO = lib.getPipelineStateForFunc(kernel);
[computeEncoder setComputePipelineState:quantizedPSO];
mtl_setBuffer(computeEncoder, A, 0);
mtl_setBuffer(computeEncoder, B, 1);
mtl_setBuffer(computeEncoder, scales, 2);
mtl_setBuffer(computeEncoder, C, 3);
[computeEncoder setBytes:sizes.data() length:sizeof(uint32_t) * sizes.size() atIndex:4];
[computeEncoder dispatchThreads:MTLSizeMake(M * N / 4, 8, 1) threadsPerThreadgroup:MTLSizeMake(8, 8, 1)];
#if _CAPTURE_KERNEL
if (getMPSProfiler().isCapturing()) {
getMPSProfiler().stopCapture(mpsStream);
}
#endif
}
});
#else
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *ATensor = nil, *BTensor = nil, *scalesTensor = nil;
Expand Down Expand Up @@ -193,6 +301,7 @@ Tensor _weight_int8pack_mm_mps(const Tensor& A, const Tensor& B, const Tensor& s
dictionaryFromPlaceholders(APlaceholder, BPlaceholder, scalesPlaceholder),
outputPlaceholder);
}
#endif

return C;
}
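
For context on the diff above, here is a hedged usage sketch of the int8 weight-only matmul this Metal kernel backs. The Python-level op name `torch._weight_int8pack_mm` and the shape contract (`A` is `[M, K]`, `B` is int8 `[N, K]`, `scales` is `[N]`, with N and K multiples of 32) are inferred from the `_weight_int8pack_mm_mps` entry point and its `TORCH_CHECK`s, so treat this as a sketch rather than documented API.

```python
import torch

# Shapes follow the checks in _weight_int8pack_mm_mps above:
# A: [M, K] activations, B: [N, K] int8 weights, scales: [N].
M, K, N = 32, 64, 64  # K and N kept as multiples of 32, per the TORCH_CHECK
A = torch.randn(M, K, dtype=torch.float16, device="mps")
B = torch.randint(-128, 127, (N, K), dtype=torch.int8, device="mps")
scales = torch.rand(N, dtype=torch.float16, device="mps")

C = torch._weight_int8pack_mm(A, B, scales)  # expected shape [M, N]
print(C.shape)
```
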
24 changes: 19 additions & 5 deletions benchmarks/dynamo/common.py
@@ -108,6 +108,7 @@
current_onnx_compiler = ""
current_batch_size = None
output_filename = None
disable_output = False

MAX_DOWNLOAD_ATTEMPTS = 5

@@ -306,6 +307,9 @@ def load_model_from_path(path_and_class_str):


def output_csv(filename, headers, row):
global disable_output
if disable_output:
return
if os.path.exists(filename):
with open(filename) as fd:
lines = list(csv.reader(fd)) or [[]]
@@ -3212,6 +3216,11 @@ def get_example_inputs(self):
"--output-directory",
help="Overrides the directory to place output files.",
)
parser.add_argument(
"--disable-output",
action="store_true",
help="Disable writing of output files, e.g., for warm-up runs",
)
parser.add_argument(
"--baseline",
help="Compare with a prior --output",
@@ -3391,6 +3400,7 @@ def get_example_inputs(self):
)
group_latency.add_argument(
"--warm-start-latency",
"--warm_start_latency",
action="store_true",
help="Run model(s) twice and preseve caches in between to enable a 'warm start' on the 2nd run",
)
@@ -3610,10 +3620,11 @@ def main(runner, original_dir=None, args=None):
cmd = [sys.executable] + sys.argv
cmd.remove("--warm-start-latency")

print(f"Executing cold-start run for {args.only}")
subprocess.check_call(cmd, timeout=args.timeout, env=env)
print(f"Performing cold-start run for {args.only}")
warmup_cmd = cmd + ["--repeat=1", "--disable-output"]
subprocess.check_call(warmup_cmd, timeout=args.timeout, env=env)

print(f"Executing warm-start run for {args.only}")
print(f"Performing warm-start run for {args.only}")
subprocess.check_call(cmd, timeout=args.timeout, env=env)
else:
# single process path just uses the main process
@@ -3666,7 +3677,7 @@ def run(runner, args, original_dir=None):
if args.ci:
if args.accuracy:
# Run fewer iterations when checking accuracy
args.repeat = 2
args.repeat = min(args.repeat, 2)

# Set translation validation on by default on CI accuracy runs.
torch.fx.experimental._config.translation_validation = True
@@ -3820,9 +3831,12 @@ def run(runner, args, original_dir=None):
runner.skip_models.clear()

experiment = null_experiment
global current_name, current_device, current_batch_size, output_filename, optimize_ctx, current_onnx_compiler
global current_name, current_device, current_batch_size, output_filename, disable_output, optimize_ctx, current_onnx_compiler
optimize_ctx = contextlib.nullcontext()

if args.disable_output:
disable_output = True

if args.overhead:
optimize_ctx = torch._dynamo.optimize(dummy_fx_compile, nopython=args.nopython)
experiment = speedup_experiment
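
The warm-start flow added to `benchmarks/dynamo/common.py` above boils down to running the same benchmark command twice and suppressing result files on the first pass. A minimal sketch of that pattern, detached from the harness (names here are illustrative):

```python
import subprocess

def run_with_warm_start(cmd: list[str], timeout: int) -> None:
    # Cold pass: populate compilation caches; keep it short and skip
    # writing result CSVs (--repeat=1 --disable-output).
    subprocess.check_call(cmd + ["--repeat=1", "--disable-output"], timeout=timeout)
    # Warm pass: identical command, now hitting warm caches; this run
    # produces the real output files.
    subprocess.check_call(cmd, timeout=timeout)
```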