2 changes: 1 addition & 1 deletion scripts/apply_patches.sh
@@ -1,6 +1,6 @@
#!/bin/bash

CDIR=$(dirname $0)
CDIR="$(cd "$(dirname "$0")" ; pwd -P)"
XDIR=$CDIR/..
PTDIR=$XDIR/..

86 changes: 62 additions & 24 deletions scripts/gen.py
@@ -7,11 +7,25 @@
import re
import sys

FuncGen = collections.namedtuple(

def namedtuple_with_defaults(typename, field_names, default_values=()):
ntuple = collections.namedtuple(typename, field_names)
ntuple.__new__.__defaults__ = (None,) * len(ntuple._fields)
if isinstance(default_values, collections.Mapping):
prototype = ntuple(**default_values)
else:
prototype = ntuple(*default_values)
ntuple.__new__.__defaults__ = tuple(prototype)
return ntuple


FuncGen = namedtuple_with_defaults(
'FuncGen',
'tree, xtree, rwxtree, func, xfunc, code, sig, rwsig, cppsig, funsig, mapsig'
)

FuncOpts = namedtuple_with_defaults('FuncOpts', 'ref_param')

_GRAMMAR = r"""
start: type fnname "(" params ")"
type: CONST? core_type refspec?
@@ -68,11 +82,16 @@
'unsafeStorageFromTH',
'unsafeTensorFromTH',
# XLA/TPU functions
'_s_copy_from',
'ones',
'ones_like',
'zeros',
'zeros_like',
])

_FN_BLACKLIST_REGEX = [
# ATEN functions
r'.*cudnn',
r'[^(]*cudnn',
# XLA/TPU functions
]

@@ -233,11 +252,16 @@ class {type_name} : public AtenXlaType {{
'empty': '.device(at::DeviceType::CPU)',
'linspace': '.device(at::DeviceType::CPU)',
'logspace': '.device(at::DeviceType::CPU)',
'ones': '.device(at::DeviceType::CPU)',
'ones_like': '.device(at::DeviceType::CPU)',
'randn': '.device(at::DeviceType::CPU)',
'zeros': '.device(at::DeviceType::CPU)',
'zeros_like': '.device(at::DeviceType::CPU)',
}

_FUNCTION_OPTIONS = {
'to(Tensor, TensorOptions, bool, bool) -> Tensor':
FuncOpts(ref_param='options'),
'to(Tensor, Device, ScalarType, bool, bool) -> Tensor':
FuncOpts(ref_param='device'),
'to(Tensor, Tensor, bool, bool) -> Tensor':
FuncOpts(ref_param='other'),
}

_RESULT_NAME = 'x_result'
@@ -317,11 +341,11 @@ def list_get(l, n):
return l[n] if n < len(l) else None


def is_blacklisted_fn(fname):
if fname in _FN_BLACKLIST:
def is_blacklisted_fn(fname, mapsig):
if fname in _FN_BLACKLIST or mapsig in _FN_BLACKLIST:
return True
for frx in _FN_BLACKLIST_REGEX:
if re.match(frx, fname):
if re.match(frx, fname) or re.match(frx, mapsig):
return True
return False

@@ -580,11 +604,11 @@ def get_return_value(rtype, rname, param, var, ref_param):
# If instead the return type is a value Tensor, we create a new one by
# wrapping the proper local variable which has been created by calling
# into the CPU tensor implementation.
return 'bridge::CreateXlaTensor({}, bridge::XlaTensorDevice({}))'.format(
return 'bridge::CreateXlaTensor({}, bridge::GetXlaDevice({}))'.format(
rname, param_name(ref_param))


def get_reference_param(params):
def get_reference_param(params, fnopts=None):
# The reference parameter is the Tensor object which we use to extract the
# result Tensor device, if any.
ref_param = None
@@ -593,13 +617,15 @@ def get_reference_param(params):
ptype = param_type(p)
cptype = type_core(ptype)
pname = param_name(p)
if cptype == 'TensorOptions' or cptype == 'TensorList':
if fnopts and fnopts.ref_param == pname:
return p
if not other and (cptype == 'TensorOptions' or cptype == 'TensorList'):
other = p
if cptype != 'Tensor':
continue
if pname == 'self' or type_is_const(ptype):
return p
ref_param = p
if not ref_param and (pname == 'self' or type_is_const(ptype)):
ref_param = p
other = p
return ref_param or other


@@ -642,7 +668,7 @@ def generate_return_stmt(t, rtype_str, fname, rname, params, param_vars,
retstr = get_tuple_return(rtype, rtype_str, rname, params, param_vars,
ref_param)
elif ctype == 'std::vector':
retstr = 'bridge::CreateXlaTensors({}, bridge::XlaTensorDevice({}))'.format(
retstr = 'bridge::CreateXlaTensors({}, bridge::GetXlaDevice({}))'.format(
rname, param_name(ref_param))
elif ctype == 'Tensor':
retstr = get_return_value(rtype, rname, params[0], param_vars[0], ref_param)
@@ -679,10 +705,12 @@ def rewrite_tensor_options(fname, pname):
def get_xla_wrapper(orig_sig, ctx):
tree = _PARSER.parse(orig_sig)
xtree = _XPARSER.parse(orig_sig)
mapsig = create_map_sig(xtree, orig_sig)
rwsig = rewrite_signature(orig_sig, _TYPE_NSMAP)
rwxtree = _XPARSER.parse(rwsig)
params = get_parameters(tree)
ref_param = get_reference_param(params)
fnopts = _FUNCTION_OPTIONS.get(mapsig, None)
ref_param = get_reference_param(params, fnopts=fnopts)

# There are a few functions with the same function name but different
# parameter list. Generate a unique XLA function name here.
@@ -698,6 +726,9 @@ def gen_fnname(x):
return 'xla_' + x + post

sig, fname, xfname = get_function_signature(rwxtree, rwsig, gen_fnname)
if is_blacklisted_fn(fname, mapsig):
return None

code = '{} {}{{\n'.format(sig, 'const ' if ctx.gen_class_mode else '')
xla_ref_param = param_name(ref_param) if ref_param else None
tfetcher = TensorFetcher('xlatens')
@@ -723,7 +754,7 @@ def gen_fnname(x):
else:
xname = tfetcher.add(pname, True)
param_vars.append(xname)
if p == ref_param:
if p == ref_param and not (fnopts and fnopts.ref_param):
xla_ref_param = param_vars[-1]
code += tfetcher.generate()
result_assign = generate_result_assignment(tree, _RESULT_NAME)
@@ -755,7 +786,7 @@ def gen_fnname(x):
rwsig=rwsig,
cppsig=sig,
funsig=create_stdfunc_sig(rwxtree, rwsig),
mapsig=create_map_sig(xtree, orig_sig))
mapsig=mapsig)


def extract_functions(path):
@@ -766,10 +797,8 @@ def extract_functions(path):
continue
fndef = m.group(1)
try:
tree = _PARSER.parse(fndef)
fname = get_function_name(tree)
if not is_blacklisted_fn(fname):
functions.append(fndef)
_XPARSER.parse(fndef)
functions.append(fndef)
except:
pass
return functions
@@ -823,7 +852,16 @@ def generate(args):
fgens = []
ctx = Context(args.functions, args.native_functions, args.gen_class_mode)
for ts in fndefs:
fgens.append(get_xla_wrapper(ts, ctx))
try:
fgen = get_xla_wrapper(ts, ctx)
if fgen:
fgens.append(fgen)
except Exception as e:
print(
'Failed to generate wrapper for {}: {}'.format(ts, e), file=sys.stderr)
print(
'Generated {} wrappers for {}'.format(len(fgens), args.typedef),
file=sys.stderr)

functions = generate_functions(fgens)
if args.gen_class_mode:
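
For readers skimming the gen.py changes above, here is a small standalone sketch of how the new FuncOpts plumbing is intended to behave: namedtuple_with_defaults gives every field a None default, _FUNCTION_OPTIONS is keyed on the mapped signature, and a FuncOpts(ref_param=...) entry overrides the heuristic that get_reference_param would otherwise apply. This is a simplified restatement of the diff, not the generator itself: the Param tuple and the example parameter list are toy stand-ins for the real parser objects, and collections.abc.Mapping is used so the snippet runs on current Python.

    import collections
    import collections.abc

    def namedtuple_with_defaults(typename, field_names, default_values=()):
        # Same idea as the helper added in gen.py: every field defaults to None
        # unless an explicit default is supplied.
        ntuple = collections.namedtuple(typename, field_names)
        ntuple.__new__.__defaults__ = (None,) * len(ntuple._fields)
        if isinstance(default_values, collections.abc.Mapping):
            prototype = ntuple(**default_values)
        else:
            prototype = ntuple(*default_values)
        ntuple.__new__.__defaults__ = tuple(prototype)
        return ntuple

    FuncOpts = namedtuple_with_defaults('FuncOpts', 'ref_param')

    # Toy stand-in for the parsed parameter objects: (name, core type, constness).
    Param = collections.namedtuple('Param', 'name ctype const')

    def get_reference_param(params, fnopts=None):
        # Simplified version of the selection logic above: an explicit
        # FuncOpts.ref_param wins; otherwise prefer 'self' or a const Tensor,
        # falling back to a TensorOptions/TensorList/other Tensor parameter.
        ref_param = None
        other = None
        for p in params:
            if fnopts and fnopts.ref_param == p.name:
                return p
            if not other and p.ctype in ('TensorOptions', 'TensorList'):
                other = p
            if p.ctype != 'Tensor':
                continue
            if not ref_param and (p.name == 'self' or p.const):
                ref_param = p
            other = p
        return ref_param or other

    _FUNCTION_OPTIONS = {
        'to(Tensor, TensorOptions, bool, bool) -> Tensor':
            FuncOpts(ref_param='options'),
    }

    params = [
        Param('self', 'Tensor', True),
        Param('options', 'TensorOptions', True),
        Param('non_blocking', 'bool', False),
    ]
    mapsig = 'to(Tensor, TensorOptions, bool, bool) -> Tensor'
    fnopts = _FUNCTION_OPTIONS.get(mapsig, None)
    print(get_reference_param(params).name)                 # self (heuristic)
    print(get_reference_param(params, fnopts=fnopts).name)  # options (override)

With the override in place, the device for 'to(Tensor, TensorOptions, ...)' is taken from the options argument rather than from self, which is what the _FUNCTION_OPTIONS table added above is for.
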
2 changes: 1 addition & 1 deletion scripts/generate_code.sh
@@ -1,6 +1,6 @@
#!/bin/bash

CDIR=$(dirname $0)
CDIR="$(cd "$(dirname "$0")" ; pwd -P)"
XDIR=$CDIR/..
PTDIR=$XDIR/..

17 changes: 7 additions & 10 deletions test/cpp/CMakeLists.txt
@@ -74,13 +74,6 @@ target_compile_options(test_ptxla PRIVATE ${TGT_OPTS})

add_dependencies(test_ptxla googletest)

if(MSVC)
set(SUFFIX ".lib")
else()
set(SUFFIX ".a")
set(PTHREAD "-pthread")
endif()

ExternalProject_Get_Property(googletest BINARY_DIR)

file(GLOB XLAC_LIBS "${PTXLA_LIBDIR}/_XLAC.*.so")
@@ -103,15 +96,19 @@ find_library(C10_LIB "libc10.so"
find_library(CAFFE_LIB "libcaffe2.so"
HINTS "${PT_DIR}/build/lib")

# Use --unresolved-symbols=ignore-in-shared-libs to get around the
# c10::Half::from_bits undefined symbol error at link time. At runtime
# everything resolves correctly.
target_link_libraries(
test_ptxla
-Wl,--unresolved-symbols=ignore-in-shared-libs
"${PTXLA_LIB}"
"${PTXLA_LIBDIR}/torch_xla/lib/libxla_computation_client.so"
"${PTPY_LIB}"
"${PT_LIB}"
"${C10_LIB}"
"${CAFFE_LIB}"
"${BINARY_DIR}/lib/${CMAKE_FIND_LIBRARY_PREFIXES}gtest${SUFFIX}"
"${C10_LIB}"
"${BINARY_DIR}/lib/${CMAKE_FIND_LIBRARY_PREFIXES}gtest.a"
"${PYTHON_LIBRARY}"
${PTHREAD}
-pthread
-lstdc++
-ldl)
12 changes: 7 additions & 5 deletions test/cpp/run_tests.sh
@@ -1,5 +1,7 @@
#!/bin/bash
set -ex
RUNDIR="$(cd "$(dirname "$0")" ; pwd -P)"
BUILDDIR="$RUNDIR/build"
VERB=
RMBUILD=1
LOGFILE=/tmp/pytorch_cpp_test.log
@@ -20,10 +22,10 @@ do
done
shift $(($OPTIND - 1))

rm -rf build
mkdir build 2>/dev/null
pushd build
cmake .. \
rm -rf "$BUILDDIR"
mkdir "$BUILDDIR" 2>/dev/null
pushd "$BUILDDIR"
cmake "$RUNDIR" \
-DPYTHON_INCLUDE_DIR=$(python -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())")\
-DPYTHON_LIBRARY=$(python -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR') + '/' + sysconfig.get_config_var('LDLIBRARY'))")
make $VERB
@@ -34,5 +36,5 @@ else
fi
popd
if [ $RMBUILD -eq 1 ]; then
rm -rf build
rm -rf "$BUILDDIR"
fi
7 changes: 4 additions & 3 deletions test/test_operations.py
@@ -1359,9 +1359,10 @@ def test_log_softmax(self):

class TestAtenXlaTensor(XlaTestCase):

def test_size(self):
x = _gen_tensor(4, 2, device=_xla_device())
torch_xla._XLAC._get_xla_tensor(x)
def test_get_xla_tensor(self):
t = _gen_tensor(4, 2, device=_xla_device())
x = torch_xla._XLAC._get_xla_tensor(t)
self.assertEqual(t.data.cpu(), x.to_tensor())


if __name__ == '__main__':
94 changes: 94 additions & 0 deletions torch_patches/16913.diff
@@ -0,0 +1,94 @@
commit 84256137efc10f794b31575a4465838982cd6426
Author: Davide Libenzi <dlibenzi@google.com>
Date: Fri Feb 8 20:07:05 2019 -0800

Allow the variable type associations to depend on backend+scalar_type, like the Type does.

diff --git a/aten/src/ATen/templates/TypeExtension.cpp b/aten/src/ATen/templates/TypeExtension.cpp
index 313b865fc..de7c0175b 100644
--- a/aten/src/ATen/templates/TypeExtension.cpp
+++ b/aten/src/ATen/templates/TypeExtension.cpp
@@ -23,7 +23,7 @@ std::unique_ptr<Generator> ${Type}::generator() const {
}

ScalarType ${Type}::scalarType() const {
- AT_ERROR("scalarType is not implemented for ${Type}");
+ return ScalarType::Float;
}

caffe2::TypeMeta ${Type}::typeMeta() const {
diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp
index 26c2b8b36..1d5fadc87 100644
--- a/torch/csrc/autograd/VariableTypeManual.cpp
+++ b/torch/csrc/autograd/VariableTypeManual.cpp
@@ -63,16 +63,16 @@ TypeID VariableType::ID() const {
return static_cast<TypeID>(id_);
}

-std::vector<std::unique_ptr<Type>> type_to_variable_type;
+static std::unique_ptr<Type> type_to_variable_type
+ [static_cast<int>(Backend::NumOptions)]
+ [static_cast<int>(ScalarType::NumOptions)];

// XXX - this is not threadsafe with uses of Variables
-void register_variable_type_for(TypeExtendedInterface* baseType) {
+void register_variable_type_for(int backend, int scalar, TypeExtendedInterface* baseType) {
AT_ASSERT(baseType);
- const auto base_id = static_cast<size_t>(baseType->ID());
- if(type_to_variable_type.size() <= base_id) {
- type_to_variable_type.resize(base_id + 1);
- }
- type_to_variable_type[base_id] =
+ AT_ASSERT(backend >= 0 && backend < static_cast<int>(Backend::NumOptions));
+ AT_ASSERT(scalar >= 0 && scalar < static_cast<int>(ScalarType::NumOptions));
+ type_to_variable_type[backend][scalar] =
make_unique<VariableType>(&at::globalContext(), baseType);
}

@@ -83,7 +83,7 @@ struct VariableTypeRegistry {
for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); ++s) {
auto baseType = context.getNonVariableTypeRaw(static_cast<Backend>(p), static_cast<ScalarType>(s));
if (baseType && baseType->backend() != Backend::Undefined) {
- register_variable_type_for(baseType);
+ register_variable_type_for(p, s, baseType);
}
}
}
@@ -128,7 +128,9 @@ REGISTER_VARIABLE_HOOKS(VariableHooks)
// Pre-condition: backend/scalar_type is a valid type in the type_registry
void VariableHooks::registerVariableTypeFor(at::LegacyTypeDispatch* context, at::Backend backend, at::ScalarType scalar_type) const {
auto* baseType = context->getNonVariableTypeRaw(backend, scalar_type);
- register_variable_type_for(static_cast<at::TypeExtendedInterface*>(baseType));
+ register_variable_type_for(static_cast<int>(backend),
+ static_cast<int>(scalar_type),
+ static_cast<at::TypeExtendedInterface*>(baseType));
}

at::Type& VariableHooks::getVariableTypeFromBaseType(const at::Type& baseType) const {
@@ -140,10 +142,9 @@ bool VariableType::isVariableType(const at::Type& type) {
}

at::TypeExtendedInterface* VariableType::getVariableTypeFromBaseType(const at::Type& baseType) {
- auto id = static_cast<size_t>(baseType.ID());
- if(id >= type_to_variable_type.size())
- return nullptr;
- return static_cast<at::TypeExtendedInterface*>(type_to_variable_type[id].get());
+ return static_cast<at::TypeExtendedInterface*>(
+ type_to_variable_type[static_cast<int>(baseType.backend())]
+ [static_cast<int>(baseType.scalarType())].get());
}

namespace {
diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h
index b37c36d40..6a0051700 100644
--- a/torch/csrc/autograd/VariableTypeUtils.h
+++ b/torch/csrc/autograd/VariableTypeUtils.h
@@ -40,8 +40,6 @@ using namespace torch::autograd::generated;

namespace torch { namespace autograd {

-extern std::vector<std::unique_ptr<Type>> type_to_variable_type;
-
inline void check_inplace(const Tensor& tensor) {
auto& var = static_cast<const Variable&>(tensor);
if (var.requires_grad() && var.is_leaf() && GradMode::is_enabled()) {