Update on "Construct CppSignatureGroup from NativeFunction"
This will make it easier to implement the POC in
peterbell10@d534f7d; see also #45666.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D25594005](https://our.internmc.facebook.com/intern/diff/D25594005)

[ghstack-poisoned]
ezyang committed Dec 17, 2020
2 parents 7efca0b + 5c132e3 commit 69d981a
Showing 130 changed files with 2,254 additions and 700 deletions.
8 changes: 4 additions & 4 deletions .circleci/scripts/binary_linux_test.sh
@@ -7,13 +7,13 @@ set -eux -o pipefail
python_nodot="\$(echo $DESIRED_PYTHON | tr -d m.u)"
# There was a bug that was introduced in conda-package-handling >= 1.6.1 that makes archives
# above a certain size fail out when attempting to extract
# see: https://github.com/conda/conda-package-handling/issues/71
conda install -y conda-package-handling=1.6.0
# Set up Python
if [[ "$PACKAGE_TYPE" == conda ]]; then
# There was a bug that was introduced in conda-package-handling >= 1.6.1 that makes archives
# above a certain size fail out when attempting to extract
# see: https://github.com/conda/conda-package-handling/issues/71
conda install -y conda-package-handling=1.6.0
retry conda create -qyn testenv python="$DESIRED_PYTHON"
source activate testenv >/dev/null
elif [[ "$PACKAGE_TYPE" != libtorch ]]; then
8 changes: 8 additions & 0 deletions aten/src/ATen/ParallelOpenMP.cpp
@@ -8,6 +8,8 @@
#include <mkl.h>
#endif

#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

namespace at {

namespace {
@@ -49,6 +51,12 @@ void set_num_threads(int nthreads) {
// See https://github.com/pytorch/pytorch/issues/13757
mkl_set_dynamic(false);
#endif
#ifdef USE_PTHREADPOOL
// because PyTorch uses caffe2::pthreadpool() in QNNPACK
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
pool->set_thread_count(nthreads);
#endif
}

// Explicitly calling omp_get_max_threads() as the size of the parallel
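
With this change, at::set_num_threads also resizes the shared caffe2 pthreadpool used by the QNNPACK/XNNPACK backends. A minimal sketch of the observable effect, assuming a build with USE_PTHREADPOOL defined (the wrapper function name is illustrative, not part of this diff):

    #include <ATen/Parallel.h>
    #include <c10/util/Exception.h>
    #include <caffe2/utils/threadpool/pthreadpool-cpp.h>

    void check_thread_counts() {
      at::init_num_threads();
      at::set_num_threads(4);  // now also resizes the caffe2 pthreadpool
      // The shared QNNPACK/XNNPACK pool observes the same count:
      TORCH_INTERNAL_ASSERT(caffe2::pthreadpool()->get_thread_count() == 4);
    }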
4 changes: 2 additions & 2 deletions aten/src/ATen/autocast_mode.cpp
@@ -341,8 +341,8 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
// The macro doesn't like this one (I think it chokes on commas inside <>) so write it manually
m.impl(TORCH_SELECTIVE_NAME("aten::native_layer_norm"),
TORCH_FN((&WrapFunction<CastPolicy::fp32,
std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double),
std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double),
std::tuple<Tensor,Tensor,Tensor> (const Tensor&, IntArrayRef, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double),
std::tuple<Tensor,Tensor,Tensor> (const Tensor&, IntArrayRef, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double),
&ADD_NS(native_layer_norm)>::type::call)));
KERNEL(ADD_NS(group_norm), "group_norm", Tensor (const Tensor &, int64_t, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double, bool), fp32)
KERNEL(ADD_NS(frobenius_norm), "frobenius_norm", Tensor (const Tensor &), fp32)
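
The replacement matches the updated native_layer_norm schema, which now takes normalized_shape as an IntArrayRef instead of two int64_t dimensions. A hedged calling-code sketch, assuming a CUDA build with autocast available (this code is not part of the diff):

    #include <ATen/ATen.h>
    #include <ATen/autocast_mode.h>

    void layer_norm_under_autocast() {
      at::autocast::set_enabled(true);
      auto x = at::randn({8, 16}, at::device(at::kCUDA).dtype(at::kHalf));
      // CastPolicy::fp32: autocast runs this op in fp32 even for fp16 inputs.
      auto out = std::get<0>(at::native_layer_norm(
          x, /*normalized_shape=*/{16}, /*weight=*/{}, /*bias=*/{}, /*eps=*/1e-5));
      at::autocast::clear_cache();
      at::autocast::set_enabled(false);
    }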
1 change: 0 additions & 1 deletion aten/src/ATen/core/aten_interned_strings.h
@@ -622,7 +622,6 @@ _(aten, signbit) \
_(aten, silu) \
_(aten, sgn) \
_(aten, sin) \
_(aten, sinc) \
_(aten, sinh) \
_(aten, size) \
_(aten, sizes) \
2 changes: 1 addition & 1 deletion aten/src/ATen/core/builtin_function.h
@@ -102,7 +102,7 @@ struct BuiltinOpFunction : public Function {

std::string pretty_print_schema() const override {
TORCH_INTERNAL_ASSERT(false);
return "";
return ""; // TODO: suppress unreachable code warning
}

Function& setSchema(c10::FunctionSchema schema) override {
84 changes: 44 additions & 40 deletions aten/src/ATen/core/function_schema.h
@@ -1,12 +1,12 @@
#pragma once

#include <c10/util/StringUtil.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/alias_info.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/alias_info.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/operator_name.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <c10/util/StringUtil.h>
#include <unordered_map>

namespace c10 {
@@ -33,8 +33,7 @@ struct Argument {
N_(std::move(N)),
default_value_(std::move(default_value)),
kwarg_only_(kwarg_only),
alias_info_(std::move(alias_info)) {
}
alias_info_(std::move(alias_info)) {}
const std::string& name() const {
return name_;
}
@@ -85,7 +84,8 @@ struct Argument {
}

Argument cloneWithType(TypePtr new_type) const {
return Argument(name_, new_type, N_, default_value_, kwarg_only_, alias_info_);
return Argument(
name_, std::move(new_type), N_, default_value_, kwarg_only_, alias_info_);
}

// this function checks whether this Argument is backward compatible with
@@ -95,9 +95,9 @@ struct Argument {
// 3) this arg must provide the same default value if old arg has one,
bool isBackwardCompatibleWith(
const Argument& old,
std::ostream* why_not=nullptr) const;
std::ostream* why_not = nullptr) const;

private:
private:
std::string name_;
TypePtr type_;
// for list types, an optional statically known length for the list
@@ -113,12 +113,10 @@ struct Argument {
};

inline bool operator==(const Argument& lhs, const Argument& rhs) {
return lhs.name() == rhs.name()
&& *lhs.type() == *rhs.type()
&& lhs.N() == rhs.N()
&& lhs.default_value() == rhs.default_value()
&& lhs.kwarg_only() == rhs.kwarg_only()
&& lhs.alias_info() == rhs.alias_info();
return lhs.name() == rhs.name() && *lhs.type() == *rhs.type() &&
lhs.N() == rhs.N() && lhs.default_value() == rhs.default_value() &&
lhs.kwarg_only() == rhs.kwarg_only() &&
lhs.alias_info() == rhs.alias_info();
}

bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs);
@@ -200,7 +198,10 @@ struct FunctionSchema {
// this should always be set no matter what
c10::optional<AliasAnalysisKind> alias_kind_;

void checkArg(const IValue& value, const Argument& argument, optional<size_t> pos) const;
void checkArg(
const IValue& value,
const Argument& argument,
optional<size_t> pos) const;

void checkSchema() const {
bool seen_default_arg = false;
@@ -223,8 +224,7 @@
}
}

public:

public:
void dump() const;

const OperatorName& operator_name() const {
@@ -257,21 +257,22 @@
}

c10::optional<int> argumentIndexWithName(const std::string& name) const {
for(size_t i = 0; i < arguments().size(); ++i) {
if(name == arguments()[i].name())
for (size_t i = 0; i < arguments().size(); ++i) {
if (name == arguments()[i].name()) {
return i;
}
}
return c10::nullopt;
}
FunctionSchema cloneWithName(std::string name, std::string overload_name) const {
FunctionSchema cloneWithName(std::string name, std::string overload_name)
const {
return FunctionSchema(
std::move(name),
std::move(overload_name),
arguments(),
returns(),
is_vararg(),
is_varret()
);
std::move(name),
std::move(overload_name),
arguments(),
returns(),
is_vararg(),
is_varret());
}
FunctionSchema cloneWithArguments(std::vector<Argument> new_arguments) const {
return FunctionSchema(
@@ -305,7 +306,8 @@ struct FunctionSchema {
// values.
void checkAndNormalizeInputs(
std::vector<IValue>& inputs,
const std::unordered_map<std::string, IValue>& kwargs) const;
const std::unordered_map<std::string, IValue>& kwargs =
std::unordered_map<std::string, IValue>{}) const;

std::string findErrorInKwargs(const std::vector<std::string>& kwargs) const;

@@ -323,7 +325,6 @@
return false;
}


// TODO remove the mutation here
bool isDefaultAliasAnalysisKind() const {
return !alias_kind_;
@@ -349,16 +350,17 @@
// schema and have the program typecheck?
// as_method - if true, treat this schema as a method and ignore
// the first argument, which will be the object in both cases
bool isSubtypeOf(const FunctionSchema& rhs, bool as_method, std::ostream* why_not=nullptr) const;
bool isSubtypeOf(
const FunctionSchema& rhs,
bool as_method,
std::ostream* why_not = nullptr) const;
};

inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) {
return lhs.name() == rhs.name()
&& lhs.overload_name() == rhs.overload_name()
&& lhs.arguments() == rhs.arguments()
&& lhs.returns() == rhs.returns()
&& lhs.is_vararg() == rhs.is_vararg()
&& lhs.is_varret() == rhs.is_varret();
return lhs.name() == rhs.name() &&
lhs.overload_name() == rhs.overload_name() &&
lhs.arguments() == rhs.arguments() && lhs.returns() == rhs.returns() &&
lhs.is_vararg() == rhs.is_vararg() && lhs.is_varret() == rhs.is_varret();
}

inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) {
@@ -368,14 +370,14 @@ inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) {
// print out Argument, which is compatible with FunctionSchema parser
// full format: Type(alias)? name=default_value
inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {

// for adjusting the ? position.
// in schema, we have Tensor?(a!) input, and t(a!)?.
// however, t?(a!) doesn't work with schema parser.
// so we always use Type(alias)? format
auto type = arg.type();
bool is_opt = type->kind() == OptionalType::Kind;
auto unopt_type = is_opt ? type->cast<OptionalType>()->getElementType() : type;
auto unopt_type =
is_opt ? type->cast<OptionalType>()->getElementType() : type;

if (unopt_type->kind() == ListType::Kind && arg.N()) {
// sized lists get size N from arg, not type
@@ -409,7 +411,9 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
return out;
}

inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema);
inline std::ostream& operator<<(
std::ostream& out,
const FunctionSchema& schema);

inline std::string toString(const FunctionSchema& schema) {
std::ostringstream str;
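
Besides the clang-format churn, the substantive change in this file is the defaulted kwargs parameter on checkAndNormalizeInputs, so positional-only call sites no longer need to pass an empty map. A minimal sketch, assuming torch::jit::parseSchema is available and using a made-up operator name:

    #include <ATen/ATen.h>
    #include <torch/csrc/jit/frontend/function_schema_parser.h>

    void normalize_positional_call() {
      auto schema = torch::jit::parseSchema(
          "example::foo(Tensor self, int dim=0) -> Tensor");
      std::vector<c10::IValue> stack{at::ones({2, 2})};
      // kwargs now defaults to an empty map; the omitted dim=0 default is
      // appended and the stack is type-checked against the schema.
      schema.checkAndNormalizeInputs(stack);
    }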
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -227,6 +227,7 @@ namespace c10 {
_(aten, lt_) \
_(aten, less) \
_(aten, less_) \
_(aten, isnan) \
_(aten, mul) \
_(aten, mul_) \
_(aten, multiply) \
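
Registering isnan in the FORALL list interns it as a static Symbol. A small sketch of what that enables, assuming the usual c10::aten symbol namespace generated by this macro (the wrapper function is illustrative):

    #include <ATen/core/interned_strings.h>
    #include <c10/util/Exception.h>

    void isnan_symbol() {
      // The static symbol and the parsed qualified string intern to the same
      // id, so JIT passes can compare node kinds against c10::aten::isnan.
      TORCH_INTERNAL_ASSERT(
          c10::aten::isnan == c10::Symbol::fromQualString("aten::isnan"));
    }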
8 changes: 8 additions & 0 deletions aten/src/ATen/core/ivalue.cpp
@@ -76,6 +76,8 @@ TypePtr IValue::type() const {
return NoneType::get();
case Tag::Tensor:
return TensorType::create(toTensor());
case Tag::Storage:
return StorageType::get();
case Tag::Double:
return FloatType::get();
case Tag::Int:
@@ -260,6 +262,8 @@ IValue IValue::equals(const IValue& rhs) const {
return false;
}
return lhs.toTensor().eq(rhs.toTensor());
case Tag::Storage:
return rhs.isStorage() && lhs.toStorage().unsafeGetStorageImpl() == rhs.toStorage().unsafeGetStorageImpl();
case Tag::Double:
return rhs.isDouble() && lhs.toDouble() == rhs.toDouble();
case Tag::Int:
@@ -310,6 +314,8 @@ size_t IValue::hash(const IValue& v) {
// Tensor __hash__ is equivalent to `id()`, so take the pointer value of
// the tensor to emulate it
return c10::get_hash(v.payload.as_int);
case Tag::Storage:
return c10::get_hash(v.payload.as_int);
case Tag::Int:
return c10::get_hash(v.payload.as_int);
case Tag::String:
@@ -647,6 +653,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
return out << v.toNone();
case IValue::Tag::Tensor:
return out << v.toTensor();
case IValue::Tag::Storage:
return out << v.toStorage().unsafeGetStorageImpl();
case IValue::Tag::Double: {
double d = v.toDouble();
int c = std::fpclassify(d);
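
As the new Tag::Storage cases show, a Storage-valued IValue compares and hashes by identity of the underlying StorageImpl, mirroring Tensor semantics. A calling-code sketch, not part of this diff:

    #include <ATen/ATen.h>
    #include <ATen/core/ivalue.h>

    void storage_identity_equality() {
      at::Tensor t = at::ones({4});
      c10::IValue a(t.storage());
      c10::IValue b(t.storage());
      // Same StorageImpl pointer -> equal.
      TORCH_INTERNAL_ASSERT(a.equals(b).toBool());
      // A freshly allocated storage is a different impl -> not equal.
      c10::IValue c(at::ones({4}).storage());
      TORCH_INTERNAL_ASSERT(!c.equals(a).toBool());
    }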
15 changes: 15 additions & 0 deletions aten/src/ATen/core/ivalue.h
@@ -105,6 +105,7 @@ struct Capsule {
#define TORCH_FORALL_TAGS(_) \
_(None) \
_(Tensor) \
_(Storage) \
_(Double) \
_(Int) \
_(Bool) \
@@ -314,6 +315,20 @@ struct CAFFE2_API IValue final {
return static_cast<at::TensorImpl*>(payload.as_intrusive_ptr);
}

IValue(at::Storage s) : tag(Tag::Storage), is_intrusive_ptr(static_cast<bool>(s)) {
// Note: a null storage is not refcounted, so while it
// is tagged as a storage, is_intrusive_ptr is set to false.
// This is not an optional optimization: our incref call
// *will not* do the right thing when called on a
// null storage impl.
payload.as_intrusive_ptr = s.unsafeReleaseStorageImpl();
}
bool isStorage() const {
return Tag::Storage == tag;
}
c10::Storage toStorage() &&;
c10::Storage toStorage() const&;

const IValue& toIValue() const {
return *this;
}
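
A minimal round-trip through the new Storage tag; this is a calling-code sketch, not part of the diff:

    #include <ATen/ATen.h>
    #include <ATen/core/ivalue.h>

    void storage_round_trip() {
      at::Tensor t = at::ones({2, 3});
      c10::IValue iv(t.storage());       // tagged Tag::Storage, refcounted
      TORCH_INTERNAL_ASSERT(iv.isStorage());
      c10::Storage s = iv.toStorage();   // const& overload bumps the refcount
      TORCH_INTERNAL_ASSERT(
          s.unsafeGetStorageImpl() == t.storage().unsafeGetStorageImpl());
    }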
10 changes: 10 additions & 0 deletions aten/src/ATen/core/ivalue_inl.h
@@ -137,6 +137,15 @@ inline at::Tensor IValue::toTensor() const& {
AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
return at::Tensor(toIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
}
inline c10::Storage IValue::toStorage() && {
AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
return c10::Storage(
moveToIntrusivePtr<at::StorageImpl>());
}
inline c10::Storage IValue::toStorage() const& {
AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
return c10::Storage(toIntrusivePtr<at::StorageImpl>());
}
inline c10::Stream IValue::toStream() && {
return c10::Stream::unpack(payload.as_int);
}
@@ -743,6 +752,7 @@ inline const ivalue::Object& IValue::toObjectRef() const {
return this->method_name(); \
}
DEFINE_TO(at::Tensor, toTensor)
DEFINE_TO(at::Storage, toStorage)
DEFINE_TO(c10::Stream, toStream)
DEFINE_TO(float, toDouble)
DEFINE_TO(double, toDouble)
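
With the DEFINE_TO entry, the generic IValue::to<T>() accessor also covers Storage. A short calling-code sketch, not part of this diff:

    #include <ATen/ATen.h>
    #include <ATen/core/ivalue.h>

    void generic_accessor() {
      c10::IValue iv(at::ones({3}).storage());
      // to<T>() dispatches through DEFINE_TO(at::Storage, toStorage).
      at::Storage s = iv.to<at::Storage>();
      TORCH_INTERNAL_ASSERT(s.unsafeGetStorageImpl() != nullptr);
    }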
