Skip to content

Commit

Permalink
Merge pull request #30 from earhart/vai-main-29
Browse files Browse the repository at this point in the history
Add correct cl_khr_fp64 support
  • Loading branch information
jbruestle committed Nov 14, 2017
2 parents b663eeb + c8ab84e commit 2515700
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
7 changes: 6 additions & 1 deletion public/plaidml/tile/hal/opencl/compiler.cc
Expand Up @@ -136,6 +136,11 @@ boost::future<std::unique_ptr<hal::Library>> Compiler::Build(const context::Cont
code << "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
}

bool cl_khr_fp64 = device_state_->HasDeviceExtension("cl_khr_fp64");
if (cl_khr_fp64) {
code << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";
}

for (const auto& ki : kernel_info) {
context::Activity kbuild{activity.ctx(), "tile::hal::opencl::BuildKernel"};

Expand All @@ -152,7 +157,7 @@ boost::future<std::unique_ptr<hal::Library>> Compiler::Build(const context::Cont
}

code << ki.comments;
Emit ocl{cl_khr_fp16};
Emit ocl{cl_khr_fp16, cl_khr_fp64};
ocl.Visit(*ki.kfunc);
code << ocl.str();
code << "\n\n";
Expand Down
15 changes: 15 additions & 0 deletions public/plaidml/tile/hal/opencl/emitocl.cc
Expand Up @@ -6,6 +6,7 @@
#include <map>
#include <utility>

#include "base/util/error.h"
#include "tile/hal/opencl/exprtype.h"
#include "tile/lang/fpconv.h"

Expand Down Expand Up @@ -110,6 +111,7 @@ void Emit::Visit(const sem::DeclareStmt &n) {
}
}
emit(";\n");
CheckValidType(ty);
scope_->Bind(n.name, ty);
}

Expand Down Expand Up @@ -258,6 +260,7 @@ void Emit::Visit(const sem::Function &n) {
// Global booleans are stored as INT8.
ty.dtype = lang::DataType::INT8;
}
CheckValidType(ty);
scope.Bind(p.second, ty);
}

Expand Down Expand Up @@ -294,6 +297,18 @@ void Emit::Visit(const sem::Function &n) {
scope_ = nullptr;
}

void Emit::CheckValidType(const sem::Type &ty) {
if (cl_khr_fp64_) {
return;
}
if (ty.base == sem::Type::TVOID || ty.base == sem::Type::INDEX) {
return;
}
if (ty.dtype == lang::DataType::FLOAT64) {
throw error::Unimplemented{"The device does not support 64-bit floating-point types"};
}
}

sem::Type Emit::TypeOf(const sem::ExprPtr &expr) { return ExprType::TypeOf(scope_, cl_khr_fp16_, expr); }

sem::Type Emit::TypeOf(const sem::LValPtr &lvalue) { return ExprType::TypeOf(scope_, cl_khr_fp16_, lvalue); }
Expand Down
5 changes: 4 additions & 1 deletion public/plaidml/tile/hal/opencl/emitocl.h
Expand Up @@ -15,7 +15,8 @@ namespace opencl {

class Emit : public lang::EmitC {
public:
explicit Emit(bool cl_khr_fp16) : cl_khr_fp16_{cl_khr_fp16}, scope_{nullptr} {}
explicit Emit(bool cl_khr_fp16, bool cl_khr_fp64)
: cl_khr_fp16_{cl_khr_fp16}, cl_khr_fp64_{cl_khr_fp64}, scope_{nullptr} {}

void Visit(const sem::LoadExpr &) final;
void Visit(const sem::StoreStmt &) final;
Expand All @@ -33,6 +34,7 @@ class Emit : public lang::EmitC {
void Visit(const sem::Function &) final;

private:
void CheckValidType(const sem::Type &ty);
sem::Type TypeOf(const sem::ExprPtr &expr);
sem::Type TypeOf(const sem::LValPtr &lvalue);
void EmitWithTypeConversion(const sem::Type &from, const sem::Type &to, const sem::ExprPtr &expr,
Expand All @@ -42,6 +44,7 @@ class Emit : public lang::EmitC {
void emitType(const sem::Type &t) final;

bool cl_khr_fp16_;
bool cl_khr_fp64_;
lang::Scope<sem::Type> *scope_;
};

Expand Down

0 comments on commit 2515700

Please sign in to comment.