
[PYTHON][OPS] Added batch normalization op

ptillet committed Oct 29, 2019
1 parent d9eacf9 commit d65a94c76843c53f7722949a493d7f77bfed814a
@@ -99,6 +99,8 @@ class function {
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options_t &opt);
caller autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);

public:
static std::string preheader();

public:
function(const std::string& src, const options_space_t& opt = options_space_t());
@@ -289,6 +289,11 @@ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) {


void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
if(!x->get_type()->is_tile_ty()){
Value *ptr = get_value(x->get_pointer_operand(), {});
set_value(x, {}, builder_->CreateLoad(ptr));
return;
}
// find vector size
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->order[0];
@@ -229,6 +229,13 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
ir::value* ret = ret_;
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
return set_ret(bld_->create_get_program_id(axis->get_value()));
else
return should_not_happen();
}
if(name == "sqrtf"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ret = ret_;
return set_ret(bld_->create_sqrt(ret));
}
return error_not_implemented();
}
@@ -274,10 +281,11 @@ void Generator::VisitDeclaration(Declaration* decl) {
// initialize declaration
ir::type::id_t id = ty->get_type_id();
if(id == ir::type::StructTyID)
- assert(false);
+ should_not_happen();
if(inits.size() > 1)
- assert(false);
- val = inits[0];
+ should_not_happen();
+ if(inits.size() > 0)
+ val = inits[0];
assert(val->get_type() == ty);
// update scope symbols table
const std::string &name = obj->Name();
@@ -113,7 +113,7 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid,
arg arg_i = args.at(i);
arg_type ty = arg_i.type();
if(ty != param_tys_.at(i))
throw std::runtime_error("invalid type");
throw std::runtime_error("invalid type for argument " + std::to_string(i));
if(ty == BUFFER_T)
bin_->setArg(i, *((driver::buffer**)arg_i.data()));
else
@@ -253,16 +253,14 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
return std::unique_ptr<driver::module>();
barriers.run(module);
// std::cout << "isel" << std::endl;
// ir::print(module, std::cout);
isel.visit(module, *llvm);
// std::cout << "done" << std::endl;
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
// done
return res;
}

- std::string preheader() {
+ std::string function::preheader() {
return
R"(
#define bool _Bool
@@ -277,6 +275,7 @@ R"(
#define __multipleof(A) __attribute__((multipleof(A)))
extern int get_program_id(int);
extern float sqrtf(float);
)";
}

@@ -77,7 +77,7 @@ def build_extension(self, ext):
pass

cfg = 'Debug' if self.debug else 'Release'
- cfg = 'Release'
+ #cfg = 'Release'
build_args = ['--config', cfg]

if platform.system() == "Windows":
@@ -211,27 +211,9 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
os << ";\n";
}

inline std::string preheader() {
return
R"(
#define bool _Bool
#define true 1
#define false 0
#define __bool_true_false_are_defined 1
#define __readonly __attribute__((readonly))
#define __writeonly __attribute__((writeonly))
#define __noalias __attribute__((noalias))
#define __aligned(A) __attribute__((aligned(A)))
#define __multipleof(A) __attribute__((multipleof(A)))
extern int get_program_id(int);
)";
}

void make_module(const std::string& src, ir::module* ir,
const runtime::function::options_space_t& opt) {
- std::string copy = preheader() + src;
+ std::string copy = triton::runtime::function::preheader() + src;
// pre-process
TokenSequence tokens;
Preprocessor cpp(&copy, true);
@@ -341,11 +323,11 @@ inline std::string to_torch_ty(ir::type *ty) {
if(ty->is_integer_ty())
return "int64_t";
if(ty->is_half_ty())
return "float16";
return "double";
if(ty->is_float_ty())
return "float32";
return "double";
if(ty->is_double_ty())
return "float64";
return "double";
if(ty->is_pointer_ty())
return "torch::Tensor";
throw std::runtime_error("unknown type");
@@ -363,11 +345,11 @@ inline std::string to_c_ty(ir::type *ty) {
if(ty->is_integer_ty(64))
return "int64_t";
if(ty->is_half_ty())
return "float16";
return "half";
if(ty->is_float_ty())
return "float32";
return "float";
if(ty->is_double_ty())
return "float64";
return "double";
if(ty->is_pointer_ty())
return "drv::cu_buffer";
throw std::runtime_error("unknown type");
@@ -1,2 +1,3 @@
from .dot import _dot, dot
from .einsum import _einsum, einsum
from .batchnorm import _batchnorm, batchnorm
@@ -0,0 +1,75 @@
import triton
import math

class _batchnorm(triton.function):

fwd_src = """
void batchnormForward(float *Y, float *M, float *V,
float *X, float *G, float *B,
int N, float rcpN, float eps) {
int rx[TM] = 0 ... TM;
float *px[TM];
float x[TM] = 0;
int c = get_program_id(1);
float g = *(G + c);
float b = *(B + c);
float mean[TM] = 0;
px = X + rx + c*N;
for(int i = 0; i < N; i = i + TM){
x = *px;
mean = mean + x;
px = px + TM;
}
float *pm = M + c;
float m = mean[+] * rcpN;
*pm = m;
float var[TM] = 0;
px = X + rx + c*N;
for(int i = 0; i < N; i = i + TM){
x = *px;
x = x - m;
var = var + x*x;
px = px + TM;
}
float v = var[+] * rcpN;
float *pv = V + c;
*pv = v;
float rstdg = 1 / sqrtf(v + eps) * g;
px = X + rx + c*N;
float* py[TM] = Y + rx + c*N;
for(int i = 0; i < N; i = i + TM){
x = *px;
float y[TM] = (x - m)*rstdg + b;
*py = y;
px = px + TM;
py = py + TM;
}
}
"""

    fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])

    @staticmethod
    def forward(ctx, x, gamma, beta, eps):
        shape = triton.shape(x)
        dtype = x.dtype
        # allocate outputs
        C, H, W, B = shape[0], shape[1], shape[2], shape[3]
        y = triton.empty(shape, dtype=dtype)
        mean = triton.empty([C], dtype=dtype)
        var = triton.empty([C], dtype=dtype)
        # execute kernels: launch a 1 x C grid, i.e. one program instance per channel;
        # TM is the tile size along the reduced (H*W*B) dimension
        N = H*W*B
        _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, N, 1./N, eps,
                              lambda opt: [1, C],
                              TM = 128)
        # save
        ctx.eps = eps
        ctx.save_for_backward(x, gamma, beta, mean, var)
        return y, mean, var


batchnorm = _batchnorm.apply
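
A minimal usage sketch (not part of this commit); it assumes the PyTorch frontend, that the ops package is importable as triton.ops (per the __init__.py hunk above), and that inputs use the C, H, W, B layout the kernel expects:

import torch
import triton

# hypothetical sizes: C=32 channels, H=W=16, batch B=8 (C, H, W, B layout assumed)
x     = torch.randn(32, 16, 16, 8, device='cuda')
gamma = torch.randn(32, device='cuda')
beta  = torch.randn(32, device='cuda')
# forward returns the normalized output plus the per-channel mean and variance
y, mean, var = triton.ops.batchnorm(x, gamma, beta, 1e-05)
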
@@ -181,22 +181,24 @@ def call(a, b, trans_a, trans_b, shape_c, bmnk,
    @staticmethod
    def forward(ctx, subscripts, a, b, bench = 0):
        ctx.save_for_backward(a, b)
        # parse
        if type(subscripts) is str:
            einsum_a, einsum_bc = subscripts.split(",")
            einsum_b, einsum_c = einsum_bc.split("->")
        else:
            einsum_a, einsum_b, einsum_c = subscripts

        shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
            einsum_a, einsum_b, einsum_c,
            triton.shape(a), triton.shape(b))
        # save for backward
        ctx.trans_a = ta
        ctx.trans_b = tb
        ctx.einsum_a = einsum_a
        ctx.einsum_b = einsum_b
        ctx.einsum_c = einsum_c
        ctx.bench = bench
        ctx.bmnk = bmnk
        # run
        return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c, bench)
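
For reference (not part of the diff): with the parsing added above, forward accepts either an einsum subscript string or an already-split (a, b, c) triple. A hedged sketch, assuming the op is exposed as triton.ops.einsum and operates on torch tensors:

import torch
import triton

a = torch.randn(4, 128, 64, device='cuda')
b = torch.randn(4, 64, 32, device='cuda')
# the two calls below should be equivalent under the assumptions above
c1 = triton.ops.einsum("bij,bjk->bik", a, b)
c2 = triton.ops.einsum(("bij", "bjk", "bik"), a, b)
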

