19 changes: 14 additions & 5 deletions torch_tvm/compiler.cpp
@@ -32,6 +32,7 @@ tvm::relay::Var TVMCompiler::convertToRelay(Value* val, TVMContext ctx) {
   auto optional_ivalue = toIValue(val);
   if (optional_ivalue.has_value()) {
     if (optional_ivalue.value().isTensor()) {
+      auto t = optional_ivalue.value().toTensor();
       val->inferTypeFrom(optional_ivalue.value().toTensor());
     } else {
       auto expr = convertToRelay(optional_ivalue.value(), ctx)
@@ -45,15 +46,23 @@ tvm::relay::Var TVMCompiler::convertToRelay(Value* val, TVMContext ctx) {
   if (val->isCompleteTensor()) {
     // Ensure that if a complete tensor has a device type it is CPU;
     // otherwise it is assumed to be CPU.
-    auto pt_t = val->type()->cast<CompleteTensorType>();
-    auto device_type = pt_t->device();
+    auto pt_t = val->type()->cast<ProfiledTensorType>();
+    TORCH_INTERNAL_ASSERT(pt_t);
+    auto optional_device_type = pt_t->device();
+    TORCH_INTERNAL_ASSERT(optional_device_type);
+    auto device_type = optional_device_type.value();
     AT_CHECK(device_type == at::DeviceType::CPU,
         "Expected CPU device type but got:", device_type);
     tvm::Array<tvm::relay::IndexExpr> sizes;
-    for (const auto& size : pt_t->sizes()) {
-      sizes.push_back(tvm::relay::IndexExpr(static_cast<int32_t>(size)));
+    const auto& varying_sizes = pt_t->sizes();
+    for (const auto& optional_size : varying_sizes.sizes()) {
+      TORCH_INTERNAL_ASSERT(optional_size);
+      sizes.push_back(tvm::relay::IndexExpr(
+          static_cast<int32_t>(optional_size.value())));
     }
-    at::ScalarType pt_type = pt_t->scalarType();
+    auto optional_dtype = pt_t->scalarType();
+    TORCH_INTERNAL_ASSERT(optional_dtype);
+    at::ScalarType pt_type = optional_dtype.value();
     auto t = tvm::relay::TensorTypeNode::make(sizes, scalarTypeToTVMType(pt_type));
     auto v = tvm::relay::VarNode::make(
         val->debugName() +
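The recurring pattern in this hunk: every accessor on ProfiledTensorType (device(), sizes(), scalarType()) reports an optional value, since profiling may not have observed that property, so each read is asserted and unwrapped before use. Below is a minimal sketch of the idiom, assuming the same includes and unqualified names as compiler.cpp above and that the accessors behave as the diff uses them; the helper completeSizesOf is hypothetical, not part of this PR.

// Hypothetical helper (not in this PR): collect fully known sizes from a
// profiled tensor type, asserting that every dimension was actually observed.
std::vector<int64_t> completeSizesOf(Value* val) {
  auto pt_t = val->type()->cast<ProfiledTensorType>();
  TORCH_INTERNAL_ASSERT(pt_t);                 // cast yields nullptr on mismatch
  const auto& varying_sizes = pt_t->sizes();   // shape with per-dimension optionals
  std::vector<int64_t> out;
  for (const auto& optional_size : varying_sizes.sizes()) {
    TORCH_INTERNAL_ASSERT(optional_size);      // dimension must be known here
    out.push_back(optional_size.value());
  }
  return out;
}

Asserting, rather than silently defaulting, keeps an unobserved shape from reaching Relay as a bogus dimension.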
6 changes: 4 additions & 2 deletions torch_tvm/operators.cpp
@@ -506,9 +506,11 @@ RegisterTVMOperator reg({
     {Symbol::fromQualString("aten::linear"),
      [](Node* node, tvm::Array<tvm::relay::Expr> inputs) {
        Value* input = node->input(0);
-       auto d_tensor = input->type()->cast<DimensionedTensorType>();
+       auto d_tensor = input->type()->cast<ProfiledTensorType>();
        if (d_tensor) {
-         int64_t n_dim = d_tensor->dim();
+         auto optional_n_dim = d_tensor->dim();
+         TORCH_INTERNAL_ASSERT(optional_n_dim);
+         int64_t n_dim = optional_n_dim.value();
          TORCH_CHECK(n_dim == 2,
              "WARNING: relay does not support dense operation on inputs more than 2 dim");
        }
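The same migration on the operator side: DimensionedTensorType guaranteed a rank, whereas ProfiledTensorType::dim() may be empty. A hedged sketch of a rank guard under that assumption follows; hasKnownRank is our name, not part of the PR, and it mirrors the unqualified names used in operators.cpp.

// Hypothetical guard (not in this PR): true only when the profiled rank is
// known and equals the expected value; an unobserved rank fails the check
// rather than tripping an assert.
bool hasKnownRank(Value* input, int64_t expected) {
  auto d_tensor = input->type()->cast<ProfiledTensorType>();
  if (!d_tensor) {
    return false;                        // no profiled type information at all
  }
  auto optional_n_dim = d_tensor->dim(); // rank may not have been observed
  return optional_n_dim && optional_n_dim.value() == expected;
}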