6 changes: 3 additions & 3 deletions torch/csrc/jit/passes/onnx/constant_fold.cpp
@@ -202,7 +202,7 @@ c10::optional<at::Tensor> runTorchSlice_opset10(
 at::Tensor runTorchArange_opset11(
     const Node* node,
     const std::vector<at::Tensor>& inputTensorValues) {
-  AT_ASSERT(inputTensorValues.size() == 3);
+  TORCH_INTERNAL_ASSERT(inputTensorValues.size() == 3);
   auto dtype = inputTensorValues[0].scalar_type();
   at::Tensor updated_val;
   switch (dtype) {
@@ -575,7 +575,7 @@ std::vector<at::Tensor> getValues(
           "getValues: Unsupported kind of constant node found.");
     }
   }
-  AT_ASSERT(inputTensorValues.size() == numInputs);
+  TORCH_INTERNAL_ASSERT(inputTensorValues.size() == numInputs);
   return inputTensorValues;
 }

@@ -618,7 +618,7 @@ void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) {
         "Constant folding not applied.");
     return;
   }
-  AT_ASSERT(b->param_node());
+  TORCH_INTERNAL_ASSERT(b->param_node());
   auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
   // Only the root block is constant-folded. Folding nested blocks is
   // not supported for now.
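
Every hunk in this PR is the same mechanical swap: AT_ASSERT is the deprecated alias for TORCH_INTERNAL_ASSERT, so the change is behavior-preserving, and the newer macro also accepts trailing arguments that are appended to the error message. A minimal sketch of the pattern, assuming a translation unit built against PyTorch's c10 headers (checkInputCount is a hypothetical helper, not part of this PR):

#include <c10/util/Exception.h>
#include <vector>

// Hypothetical helper showing the invariant-check pattern used in this PR.
// TORCH_INTERNAL_ASSERT throws c10::Error on failure, exactly as the
// deprecated AT_ASSERT alias did; the trailing arguments are stringified
// and appended to the error message.
void checkInputCount(const std::vector<int>& inputs, size_t expected) {
  TORCH_INTERNAL_ASSERT(
      inputs.size() == expected,
      "expected ", expected, " inputs, got ", inputs.size());
}
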
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/onnx/function_substitution.cpp
@@ -85,7 +85,7 @@ void functionCallSubstitution(Block* block) {
     Node* cur = *it++;
     switch (cur->kind()) {
       case prim::CallFunction: {
-        AT_ASSERT(cur->input(0)->node()->kind() == prim::Constant);
+        TORCH_INTERNAL_ASSERT(cur->input(0)->node()->kind() == prim::Constant);
         auto function_constant = cur->input(0)->node();
         auto fun_type =
             function_constant->output()->type()->expect<FunctionType>();
8 changes: 4 additions & 4 deletions (file name not shown in the capture)
@@ -130,7 +130,7 @@ std::unordered_map<int64_t, ConvertedIndex> MergeSliceAndSelectToIndices(
     cur_dim++;
   }
 
-  AT_ASSERT(cur_dim == dim);
+  TORCH_INTERNAL_ASSERT(cur_dim == dim);
   if (node->kind() == aten::slice) {
     auto size = CreateSizeOfDim(orig_data, dim, index_put_node);
     auto index_tensor = ConvertSliceToIndex(node, size, index_put_node);
@@ -165,7 +165,7 @@ std::unordered_map<int64_t, ConvertedIndex> MergeSliceAndSelectToIndices(
   }
 
   // Each dimension should have its associated index tensor.
-  AT_ASSERT((int64_t)dim_index_map.size() == cur_dim);
+  TORCH_INTERNAL_ASSERT((int64_t)dim_index_map.size() == cur_dim);
   return dim_index_map;
 }

@@ -190,7 +190,7 @@ std::vector<Value*> ReshapeToAdvancedIndexingFormat(
   size_t tensor_ind_count = 0;
   for (const auto i : c10::irange(dim_index_map.size())) {
     auto index_i = dim_index_map.find(i);
-    AT_ASSERT(index_i != dim_index_map.end());
+    TORCH_INTERNAL_ASSERT(index_i != dim_index_map.end());
     if (index_i->second.orig_node_kind == aten::index) {
       if (i < min_index_dim)
         min_index_dim = i;
@@ -212,7 +212,7 @@ std::vector<Value*> ReshapeToAdvancedIndexingFormat(
   for (const auto i : c10::irange(dim_index_map.size())) {
     size_t ind_size = 0;
     auto index_i = dim_index_map.find(i);
-    AT_ASSERT(index_i != dim_index_map.end());
+    TORCH_INTERNAL_ASSERT(index_i != dim_index_map.end());
     Value* index = index_i->second.index;
     switch (index_i->second.orig_node_kind) {
       case aten::select:
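
The asserts in this file all guard one invariant: after merging slice/select nodes, dim_index_map must hold an entry for every dimension 0..N-1, with no gaps. A stripped-down illustration of that contiguity check (the map value type is simplified to int; ConvertedIndex is not reproduced here, and plain assert stands in for TORCH_INTERNAL_ASSERT):

#include <cassert>
#include <cstdint>
#include <unordered_map>

int main() {
  // Simplified stand-in for dim_index_map: dimension -> index tensor id.
  std::unordered_map<int64_t, int> dim_index_map = {{0, 10}, {1, 11}, {2, 12}};

  // Mirror of the PR's lookup pattern: every dimension in [0, size())
  // must be present in the map.
  for (int64_t i = 0; i < static_cast<int64_t>(dim_index_map.size()); ++i) {
    auto it = dim_index_map.find(i);
    assert(it != dim_index_map.end());
  }
  return 0;
}
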
8 changes: 4 additions & 4 deletions torch/csrc/jit/passes/onnx/peephole.cpp
@@ -56,11 +56,11 @@ bool isNopTranspose(const std::vector<int64_t>& perm) {
 std::vector<int64_t> composeTransposes(
     const std::vector<int64_t>& t1,
     const std::vector<int64_t>& t2) {
-  AT_ASSERT(t1.size() == t2.size());
+  TORCH_INTERNAL_ASSERT(t1.size() == t2.size());
   std::vector<int64_t> ret;
   ret.reserve(t1.size());
   for (const auto& i : t2) {
-    AT_ASSERT(i < int64_t(t1.size()));
+    TORCH_INTERNAL_ASSERT(i < int64_t(t1.size()));
     ret.push_back(t1[i]);
   }
   return ret;
@@ -131,7 +131,7 @@ void fuseBroadcast(Block* b) {
 
     auto broadcast_positions = getBroadcastPositions(n);
     if (!broadcast_positions.empty()) {
-      AT_ASSERT(!n->hasAttribute(attr::axis));
+      TORCH_INTERNAL_ASSERT(!n->hasAttribute(attr::axis));
     }
 
     for (size_t position : broadcast_positions) {
@@ -627,7 +627,7 @@ static void speculateOps(Block* block) {
 static void replaceInputWithList(Node* node, size_t i, ArrayRef<Value*> to) {
   node->removeInput(i);
   for (auto* to_val : to) {
-    AT_ASSERT(to_val->owningGraph() == node->owningGraph());
+    TORCH_INTERNAL_ASSERT(to_val->owningGraph() == node->owningGraph());
     node->insertInput(i++, to_val);
   }
 }
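
Aside on the composeTransposes hunk above: the function fuses two consecutive transposes into a single permutation by indexing the first perm with the second, i.e. ret[i] = t1[t2[i]]. A self-contained sketch of that rule (plain assert stands in for TORCH_INTERNAL_ASSERT; compose is a local copy of the logic, not the PyTorch symbol):

#include <cassert>
#include <cstdint>
#include <vector>

// Local stand-in for composeTransposes: applying transpose t1 and then
// transpose t2 is equivalent to a single transpose whose permutation is
// ret[i] = t1[t2[i]].
std::vector<int64_t> compose(
    const std::vector<int64_t>& t1,
    const std::vector<int64_t>& t2) {
  assert(t1.size() == t2.size());
  std::vector<int64_t> ret;
  ret.reserve(t1.size());
  for (const auto& i : t2) {
    assert(i < static_cast<int64_t>(t1.size()));
    ret.push_back(t1[i]);
  }
  return ret;
}

int main() {
  // Swapping dims 0/1 and then dims 1/2 collapses to the single perm {1, 2, 0}.
  assert((compose({1, 0, 2}, {0, 2, 1}) == std::vector<int64_t>{1, 2, 0}));
  return 0;
}
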
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp
@@ -75,7 +75,7 @@ static bool IsComparisonOp(const NodeKind& nkind) {
 static TensorTypePtr CreateProfiledTensorTypeWithScalarType(
     const TensorTypePtr& typePtr,
     const c10::ScalarType& scalar_type) {
-  AT_ASSERT(typePtr != nullptr);
+  TORCH_INTERNAL_ASSERT(typePtr != nullptr);
   return typePtr->withScalarType({scalar_type});
 }
 
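
One behavioral note on the swap, relevant to all five files: neither macro aborts the process; both report failure by throwing c10::Error, so existing callers that catch c10::Error are unaffected. A toy demonstration, assuming it is compiled and linked against c10:

#include <c10/util/Exception.h>
#include <iostream>

int main() {
  try {
    // A deliberately false invariant: TORCH_INTERNAL_ASSERT throws
    // c10::Error annotated with file and line information.
    TORCH_INTERNAL_ASSERT(1 + 1 == 3, "arithmetic invariant violated");
  } catch (const c10::Error& e) {
    std::cout << e.what() << std::endl;
  }
  return 0;
}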