Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions aten/src/ATen/core/List.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,12 @@ class List final {
*/
void push_back(T&& value) const;

/**
* Appends the given list to the end of the container. Uses at most one memory allocation.
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
*/
void append(List<T> lst) const;

/**
* Appends the given element value to the end of the container.
* The new element is constructed with the given arguments.
Expand Down
9 changes: 9 additions & 0 deletions aten/src/ATen/core/List_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,15 @@ void List<T>::push_back(T&& value) const {
impl_->list.push_back(detail::list_element_from<T, StorageT>(std::move(value)));
}

template<class T>
void List<T>::append(List<T> b) const {
if (b.use_count() == 1) {
impl_->list.insert(impl_->list.end(), make_move_iterator(b.impl_->list.begin()), make_move_iterator(b.impl_->list.end()));
} else {
impl_->list.insert(impl_->list.end(), b.impl_->list.begin(), b.impl_->list.end());
}
}

template<class T>
template<class... Args>
void List<T>::emplace_back(Args&&... args) const {
Expand Down
14 changes: 13 additions & 1 deletion torch/csrc/jit/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,14 @@ int listAdd(Stack& stack) {
push(stack, std::move(ret));
return 0;
}
template <class T>
int listInplaceAdd(Stack& stack) {
c10::List<T> b = pop(stack).to<List<T>>();
c10::List<T> a = pop(stack).to<List<T>>();
a.append(std::move(b));
push(stack, std::move(a));
return 0;
}

template <class T>
int listMulIntLeft(Stack& stack) {
Expand Down Expand Up @@ -1915,6 +1923,10 @@ RegisterOperators reg2({
"aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type \
"[]", \
listAdd<c_type::value_type>), \
Operator( \
"aten::add_(" decl_type "[](a!) self, " decl_type "[] b) -> " decl_type \
"[]", \
listInplaceAdd<c_type::value_type>), \
Operator( \
"aten::slice(" decl_type \
"[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \
Expand Down Expand Up @@ -2068,7 +2080,7 @@ RegisterOperators reg2({
"prim::min(int[] x) -> int",
[](Stack& stack) {
c10::List<int64_t> int_list = pop(stack).toIntList();
int64_t min_element = std::numeric_limits<int64_t>::max();
int64_t min_element = std::numeric_limits<int64_t>::max();

for(int64_t ele: int_list) {
if(ele < min_element) {
Expand Down
25 changes: 16 additions & 9 deletions torch/csrc/jit/script/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1460,7 +1460,14 @@ struct to_ir {
// Get the appropriate builtin op for this augmented assignment
// If the RHS is a tensor, return the corresponding ATen in-place op
// If it's a list of scalars, then return the corresponding list augment op
Symbol getAugOp(const AugAssign& stmt, bool isTensor) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure what you're doing here, this seems unrelated. Can we have this in a separate PR?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what's needed to make a += b do the inplace add instead of desugaring it into a = a + b.

Symbol getAugOp(const AugAssign& stmt, const TypePtr& type) {
if (type->cast<ListType>()) { // Lists also have in-place ops.
switch (stmt.aug_op()) {
case '+':
return aten::add_;
}
}
bool isTensor = type->isSubtypeOf(TensorType::get());
switch (stmt.aug_op()) {
case '+':
return isTensor ? aten::add_ : aten::add;
Expand Down Expand Up @@ -1524,7 +1531,7 @@ struct to_ir {
emitBuiltinCall(
stmt.range(),
*method.graph(),
getAugOp(stmt, /*isTensor=*/true),
getAugOp(stmt, lhsValue->type()),
self,
{rhs},
{},
Expand All @@ -1541,14 +1548,15 @@ struct to_ir {
const auto lhs = Var(stmt.lhs());
const auto lhsValue = environment_stack->getSugaredVar(lhs.name())
->asValue(lhs.range(), method);
if (lhsValue->type()->isSubtypeOf(TensorType::get())) {
auto lhsType = lhsValue->type();
if (lhsType->isSubtypeOf(TensorType::get()) || lhsType->cast<c10::ListType>()) {
// for tensors, emit the corresponding in-place op
const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
const auto self = NamedValue(stmt.lhs().range(), "self", lhsValue);
const auto output = emitBuiltinCall(
stmt.range(),
*method.graph(),
getAugOp(stmt, /*isTensor=*/true),
getAugOp(stmt, lhsValue->type()),
self,
{rhs},
{},
Expand Down Expand Up @@ -1589,7 +1597,7 @@ struct to_ir {
emitBuiltinCall(
stmt.range(),
*method.graph(),
getAugOp(stmt, /*isTensor=*/true),
getAugOp(stmt, sliceable->type()),
slicedArg,
{rhs},
{},
Expand All @@ -1606,7 +1614,7 @@ struct to_ir {
const auto augmented = emitBuiltinCall(
stmt.range(),
*method.graph(),
getAugOp(stmt, /*isTensor=*/true),
getAugOp(stmt, sliceable->type()),
indexed,
{rhs},
{},
Expand All @@ -1624,8 +1632,7 @@ struct to_ir {
const auto listType = sliceable->type()->cast<ListType>();
AT_ASSERT(listType != nullptr);

bool isTensorList =
listType->getElementType()->isSubtypeOf(TensorType::get());
auto elementType = listType->getElementType();

// Get the idx to augment
const auto subscriptExprs = lhs.subscript_exprs();
Expand All @@ -1645,7 +1652,7 @@ struct to_ir {
const auto getItem =
graph->insert(aten::select, {listArg, idxArg}, {}, stmt.range());
const auto augmentedItem = graph->insert(
getAugOp(stmt, isTensorList), {getItem, valueArg}, {}, stmt.range());
getAugOp(stmt, elementType), {getItem, valueArg}, {}, stmt.range());
graph->insert(
aten::_set_item, {listArg, idxArg, augmentedItem}, {}, stmt.range());
}
Expand Down