Merge pull request PaddlePaddle#7 from qingshui/paddle_v2.4.2
add flash-attention, CublasLt MulAndAdd
laipaang committed Nov 28, 2023
2 parents 142bef2 + c599685 commit 7e34aef
Showing 44 changed files with 3,909 additions and 685 deletions.
112 changes: 112 additions & 0 deletions cmake/external/flashattn.cmake
@@ -0,0 +1,112 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

include(ExternalProject)

add_definitions(-DPADDLE_WITH_FLASHATTN)

set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
set(FLASHATTN_TAG 18106c1ba0ccee81b97ca947397c08a141815a47)

set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
CACHE PATH "flash-attn Directory" FORCE)
set(FLASHATTN_LIB_DIR
"${FLASHATTN_INSTALL_DIR}/lib"
CACHE PATH "flash-attn Library Directory" FORCE)

if(WIN32)
set(FLASHATTN_LIBRARIES
"${FLASHATTN_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn Library" FORCE)
else()
set(FLASHATTN_LIBRARIES
"${FLASHATTN_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn Library" FORCE)
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
OR WIN32)
set(USE_OMP OFF)
else()
set(USE_OMP ON)
endif()

if(WIN32)
set(FLASHATTN_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_C_FLAGS_DEBUG
$<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
set(FLASHATTN_C_FLAGS_RELEASE
$<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_CXX_FLAGS_RELEASE
$<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_CXX_FLAGS_DEBUG
$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
else()
set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()

ExternalProject_Add(
extern_flashattn
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/PaddlePaddle/flash-attention.git"
GIT_TAG 0598fa245bbfb8c4462002600864518c0e37e714
SOURCE_DIR ${SOURCE_DIR}
PREFIX ${FLASHATTN_PREFIX_DIR}
SOURCE_SUBDIR ${FLASHATTN_SOURCE_SUBDIR}
UPDATE_COMMAND ""
PATCH_COMMAND ""
#BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${FLASHATTN_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${FLASHATTN_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
-DBUILD_SHARED=ON
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_JOB_POOL_COMPILE:STRING=compile
-DCMAKE_JOB_POOLS:STRING=compile=4
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_INSTALL_DIR}
BUILD_BYPRODUCTS ${FLASHATTN_LIBRARIES})

message(STATUS "flash-attn library: ${FLASHATTN_LIBRARIES}")
get_filename_component(FLASHATTN_LIBRARY_PATH ${FLASHATTN_LIBRARIES} DIRECTORY)
include_directories(${FLASHATTN_INCLUDE_DIR})

add_library(flashattn INTERFACE)
#set_property(TARGET flashattn PROPERTY IMPORTED_LOCATION ${FLASHATTN_LIBRARIES})
add_dependencies(flashattn extern_flashattn)
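Since flashattn.cmake defines PADDLE_WITH_FLASHATTN globally via add_definitions(), C++ code can feature-gate on that macro. A minimal sketch of such a guard (illustration only, not part of this commit; the header name is hypothetical):

#ifdef PADDLE_WITH_FLASHATTN
#include "flash_attn.h"  // hypothetical header under FLASHATTN_INCLUDE_DIR
#endif

// Returns true when Paddle was configured with the extern_flashattn target.
bool IsFlashAttnCompiled() {
#ifdef PADDLE_WITH_FLASHATTN
  return true;
#else
  return false;  // CUDA < 11.4, or flash-attn disabled at configure time
#endif
}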
6 changes: 6 additions & 0 deletions cmake/third_party.cmake
@@ -499,6 +499,12 @@ if(WITH_GPU
include(external/cutlass) # download, build, install cutlass
list(APPEND third_party_deps extern_cutlass)
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
message(STATUS "add flashattn lib")
include(external/flashattn)
list(APPEND third_party_deps extern_flashattn)
set(WITH_FLASHATTN ON)
endif()
endif()

add_custom_target(third_party ALL DEPENDS ${third_party_deps})
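The hunk above enables flash-attn only when the CUDA compiler is at least 11.4, and sets WITH_FLASHATTN accordingly. As an assumed companion (not from this commit), a translation unit can mirror the same gate at compile time, since <cuda.h> defines CUDA_VERSION as major*1000 + minor*10 (11.4 -> 11040):

#include <cuda.h>

#if defined(PADDLE_WITH_FLASHATTN) && CUDA_VERSION >= 11040
// flash-attention kernels may be compiled in this branch
#endif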
71 changes: 1 addition & 70 deletions paddle/fluid/framework/executor.cc
@@ -489,76 +489,7 @@ void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector> gc;
if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
if (platform::is_gpu_place(place_)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (IsFastEagerDeletionModeEnabled()) {
gc.reset(new UnsafeFastGPUGarbageCollector(place_, max_memory_size));
} else {
gc.reset(new DefaultStreamGarbageCollector(place_, max_memory_size));
}
#else
PADDLE_THROW(
platform::errors::Unimplemented("No GPU gc found in CPU/XPU paddle"));
#endif
} else if (platform::is_cpu_place(place_)) {
gc.reset(new CPUGarbageCollector(place_, max_memory_size));
} else if (platform::is_xpu_place(place_)) {
#ifdef PADDLE_WITH_XPU
gc.reset(new XPUGarbageCollector(place_, max_memory_size));
#else
PADDLE_THROW(
platform::errors::Unimplemented("No XPU gc found in CPU/GPU paddle"));
#endif
} else if (platform::is_ipu_place(place_)) {
#ifdef PADDLE_WITH_IPU
gc.reset(new IPUGarbageCollector(place_, max_memory_size));
#else
PADDLE_THROW(
platform::errors::Unimplemented("No IPU gc found in CPU/IPU paddle"));
#endif
} else if (platform::is_npu_place(place_)) {
#ifdef PADDLE_WITH_ASCEND_CL
if (IsFastEagerDeletionModeEnabled()) {
VLOG(4) << "Use unsafe fast gc for NPU.";
gc.reset(new NPUUnsafeFastGarbageCollector(place_, max_memory_size));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Please set FLAGS_fast_eager_deletion_mode=true to use "
"GarbageCollector on NPU."));
// TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector.
VLOG(4) << "Use default stream gc for NPU.";
gc.reset(new NPUDefaultStreamGarbageCollector(place_, max_memory_size));
}
#else
PADDLE_THROW(
platform::errors::Unimplemented("No NPU gc found in CPU/NPU paddle"));
#endif
} else if (platform::is_mlu_place(place_)) {
#ifdef PADDLE_WITH_MLU
if (IsFastEagerDeletionModeEnabled()) {
gc.reset(new MLUUnsafeFastGarbageCollector(place_, max_memory_size));
} else {
gc.reset(new MLUDefaultStreamGarbageCollector(place_, max_memory_size));
}
#else
PADDLE_THROW(
platform::errors::Unimplemented("No MLU gc found in CPU/MLU paddle"));
#endif
} else if (platform::is_custom_place(place_)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (IsFastEagerDeletionModeEnabled()) {
VLOG(4) << "Use unsafe fast gc for " << place_ << ".";
gc.reset(new CustomDeviceUnsafeFastGarbageCollector(place_,
max_memory_size));
} else {
VLOG(4) << "Use default stream gc for " << place_ << ".";
gc.reset(
new CustomDefaultStreamGarbageCollector(place_, max_memory_size));
}
#else
PADDLE_THROW(platform::errors::Unimplemented("No CustomDevice gc found"));
#endif
}
gc = CreateGarbageCollector(place_, max_memory_size);
}

for (int64_t i = start_op_index; i < end_op_index; ++i) {
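The block deleted above is the per-place collector selection that this commit moves into CreateGarbageCollector() (see garbage_collector.cc below); the call site shrinks to the single gc = CreateGarbageCollector(place_, max_memory_size) line. The op loop is truncated here; a hedged sketch of how such a collector is typically consumed in it (assumed from the surrounding executor code, not shown in this hunk):

for (int64_t i = start_op_index; i < end_op_index; ++i) {
  auto& op = ctx->ops_[i];
  op->Run(*local_scope, place_);
  if (gc) {
    // Frees tensors that no later op in this block reads.
    DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
  }
}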
77 changes: 76 additions & 1 deletion paddle/fluid/framework/garbage_collector.cc
@@ -271,6 +271,81 @@ void SetEagerDeletionMode(double threshold, double fraction, bool fast_mode) {
double GetEagerDeletionMemoryFraction() {
return FLAGS_memory_fraction_of_eager_deletion;
}

// create garbage collector
std::unique_ptr<GarbageCollector> CreateGarbageCollector(
const platform::Place &place, const size_t max_memory_size) {
std::unique_ptr<GarbageCollector> gc = nullptr;
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (IsFastEagerDeletionModeEnabled()) {
gc.reset(new UnsafeFastGPUGarbageCollector(place, max_memory_size));
} else {
gc.reset(new DefaultStreamGarbageCollector(place, max_memory_size));
}
#else
PADDLE_THROW(
platform::errors::Unimplemented("No GPU gc found in CPU/XPU paddle"));
#endif
} else if (platform::is_cpu_place(place)) {
gc.reset(new CPUGarbageCollector(place, max_memory_size));
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
gc.reset(new XPUGarbageCollector(place, max_memory_size));
#else
PADDLE_THROW(
platform::errors::Unimplemented("No XPU gc found in CPU/GPU paddle"));
#endif
} else if (platform::is_ipu_place(place)) {
#ifdef PADDLE_WITH_IPU
gc.reset(new IPUGarbageCollector(place, max_memory_size));
#else
PADDLE_THROW(
platform::errors::Unimplemented("No IPU gc found in CPU/IPU paddle"));
#endif
} else if (platform::is_npu_place(place)) {
#ifdef PADDLE_WITH_ASCEND_CL
if (IsFastEagerDeletionModeEnabled()) {
VLOG(4) << "Use unsafe fast gc for NPU.";
gc.reset(new NPUUnsafeFastGarbageCollector(place, max_memory_size));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Please set FLAGS_fast_eager_deletion_mode=true to use "
"GarbageCollector on NPU."));
// TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector.
VLOG(4) << "Use default stream gc for NPU.";
gc.reset(new NPUDefaultStreamGarbageCollector(place, max_memory_size));
}
#else
PADDLE_THROW(
platform::errors::Unimplemented("No NPU gc found in CPU/NPU paddle"));
#endif
} else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_MLU
if (IsFastEagerDeletionModeEnabled()) {
gc.reset(new MLUUnsafeFastGarbageCollector(place, max_memory_size));
} else {
gc.reset(new MLUDefaultStreamGarbageCollector(place, max_memory_size));
}
#else
PADDLE_THROW(
platform::errors::Unimplemented("No MLU gc found in CPU/MLU paddle"));
#endif
} else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (IsFastEagerDeletionModeEnabled()) {
VLOG(4) << "Use unsafe fast gc for " << place << ".";
gc.reset(new CustomDeviceUnsafeFastGarbageCollector(place,
max_memory_size));
} else {
VLOG(4) << "Use default stream gc for " << place << ".";
gc.reset(
new CustomDefaultStreamGarbageCollector(place, max_memory_size));
}
#else
PADDLE_THROW(platform::errors::Unimplemented("No CustomDevice gc found"));
#endif
}
return std::unique_ptr<GarbageCollector>(gc.release());
}
} // namespace framework
} // namespace paddle
4 changes: 3 additions & 1 deletion paddle/fluid/framework/garbage_collector.h
@@ -285,6 +285,8 @@ bool IsFastEagerDeletionModeEnabled();
void SetEagerDeletionMode(double threshold, double fraction, bool fast_mode);

double GetEagerDeletionMemoryFraction();

// Create a garbage collector for the given place; implemented in garbage_collector.cc.
extern std::unique_ptr<GarbageCollector> CreateGarbageCollector(
const platform::Place &place, const size_t max_memory_size);
} // namespace framework
} // namespace paddle
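A minimal usage sketch of the new factory (assumed call pattern, mirroring the executor.cc change above; RunBlockWithGc is a hypothetical helper):

#include "paddle/fluid/framework/garbage_collector.h"

void RunBlockWithGc(const paddle::platform::Place& place) {
  // A negative threshold means eager deletion is disabled.
  int64_t max_memory_size = paddle::framework::GetEagerDeletionThreshold();
  if (max_memory_size >= 0) {
    std::unique_ptr<paddle::framework::GarbageCollector> gc =
        paddle::framework::CreateGarbageCollector(place, max_memory_size);
    // ... run operators, handing gc.get() to DeleteUnusedTensors(...)
  }
}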
49 changes: 47 additions & 2 deletions paddle/fluid/framework/naive_executor.cc
@@ -25,6 +25,11 @@
#if PADDLE_WITH_TENSORRT
#include "paddle/fluid/operators/tensorrt/tensorrt_engine_op.h"
#endif
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/string/string_helper.h"
PADDLE_DEFINE_EXPORTED_bool(enable_opt_infer_gc_var,
false,
"enable opt infer gc var");

namespace paddle {
namespace framework {
@@ -37,9 +42,20 @@ void NaiveExecutor::Prepare(Scope *scope,
} else {
scope_ = scope;
}
root_scope_ = scope;
while (root_scope_->parent()) {
root_scope_ = root_scope_->parent();
}

gc_ = nullptr;
int64_t max_memory_size = GetEagerDeletionThreshold();
if (FLAGS_enable_opt_infer_gc_var && max_memory_size >= 0) {
auto gc = CreateGarbageCollector(place_, max_memory_size);
gc_ = gc.release();
}

VLOG(3) << "NaiveExecutor init with scope " << scope;
CreateOps(program_desc, block_id, with_feed_fetch_ops);
VLOG(3) << "NaiveExecutor init with scope " << scope;
}

void NaiveExecutor::Run() {
@@ -104,7 +120,21 @@ void NaiveExecutor::CreateVariables(const ProgramDesc &desc,
void NaiveExecutor::CreateOps(const ProgramDesc &desc,
int block_id,
bool with_feed_fetch_ops) {
for (const auto &op_desc : desc.Block(block_id).AllOps()) {
auto &global_block = desc.Block(block_id);
// Create operators from the block's op descriptors.
auto ops_desc = global_block.AllOps();
for (const auto &op_desc : ops_desc) {
// Record feed/fetch variables so the GC never frees them.
if (gc_) {
if (op_desc->Type() == "feed" || op_desc->Type() == "fetch") {
for (auto &o : op_desc->Inputs()) {
skip_vars_.insert(skip_vars_.end(), o.second.begin(), o.second.end());
}
for (auto &o : op_desc->Outputs()) {
skip_vars_.insert(skip_vars_.end(), o.second.begin(), o.second.end());
}
}
}
if (!with_feed_fetch_ops &&
(op_desc->Type() == "feed" || op_desc->Type() == "fetch")) {
LOG(INFO) << "--- skip [" << op_desc->Input("X")[0] << "], "
@@ -113,6 +143,11 @@ void NaiveExecutor::CreateOps(const ProgramDesc &desc,
}
ops_.emplace_back(OpRegistry::CreateOp(*op_desc));
}
if (!gc_) {
return;
}
// Compute which variables become unused after each op and can be collected.
unused_vars_ = GetUnusedVars(global_block, ops_, skip_vars_);
}

LoDTensor *NaiveExecutor::FindTensor(const std::string &name) {
@@ -136,13 +171,23 @@ void NaiveExecutor::CleanFeedFetchOps() {
}
ops_.swap(ops);
}
void NaiveExecutor::AddSkipVars(const std::vector<std::string> &skip_vars) {
if (skip_vars.empty()) {
return;
}
skip_vars_.insert(skip_vars_.end(), skip_vars.begin(), skip_vars.end());
}

NaiveExecutor::~NaiveExecutor() {
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working
platform::ClearMKLDNNCache(place_, this);
#endif
if (gc_) {
delete gc_;
gc_ = nullptr;
}
}

void NaiveExecutor::ResetTrtOps(int num) {
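The new inference-time GC path is off by default. A hedged sketch of switching it on (assumed usage; note that gc_ is a raw pointer released from the factory's unique_ptr in Prepare() and freed in the destructor above):

#include "gflags/gflags.h"

DECLARE_bool(enable_opt_infer_gc_var);  // exported by naive_executor.cc above

void EnableInferGc() {
  // Equivalent to passing --enable_opt_infer_gc_var=true on the command line.
  FLAGS_enable_opt_infer_gc_var = true;
  // With the flag set and a non-negative eager-deletion threshold,
  // NaiveExecutor::Prepare() creates a collector, and CreateOps() computes
  // unused_vars_ while skipping every feed/fetch variable in skip_vars_.
}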
