Commit 50107ba

template work

wickedfoo committed Jun 10, 2016
1 parent f545703 commit 50107ba
Showing 51 changed files with 2,809 additions and 1,799 deletions.
219 changes: 184 additions & 35 deletions TensorMath.lua
@@ -24,14 +24,21 @@ static int torch_isnonemptytable(lua_State *L, int idx)
}
]])

-- specific to CUDA
local typename = 'CudaTensor'

-- Lua 5.2 compatibility
local unpack = unpack or table.unpack

-- specific to CUDA
local typenames = {'CudaByteTensor',
'CudaCharTensor',
'CudaShortTensor',
'CudaIntTensor',
'CudaLongTensor',
'CudaTensor',
'CudaDoubleTensor'}

for _, typename in ipairs(typenames) do
-- cut and paste from wrap/types.lua
wrap.types.CudaTensor = {
wrap.types[typename] = {

helpname = function(arg)
if arg.dim then
@@ -120,7 +127,7 @@ wrap.types.CudaTensor = {
end
}

wrap.types.CudaTensorArray = {
wrap.types[typename .. 'Array'] = {

helpname = function(arg)
return string.format('{%s+}', typename)
@@ -181,7 +188,7 @@ wrap.types.CudaTensorArray = {
return string.format('THFree(arg%d_data);', arg.i)
end
}

end
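A note on the loop above: it relies on Lua giving every iteration of a generic for its own `typename` local, so the `helpname` and sibling closures stored under `wrap.types[typename]` and `wrap.types[typename .. 'Array']` each capture their own type name. A standalone sketch of that capture behaviour (the `registry` table and trivial `helpname` below are illustrative, not the real wrap API):

-- Standalone sketch, not the committed code: each iteration binds a fresh
-- `typename` upvalue, so stored closures keep naming their own type.
local registry = {}
for _, typename in ipairs({'CudaIntTensor', 'CudaDoubleTensor'}) do
   registry[typename] = { helpname = function() return typename end }
end
print(registry.CudaIntTensor.helpname())    --> CudaIntTensor
print(registry.CudaDoubleTensor.helpname()) --> CudaDoubleTensor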

wrap.types.LongArg = {

@@ -289,26 +296,6 @@ wrap.types.charoption = {
end
}

function interface.luaname2wrapname(self, name)
return string.format('cutorch_CudaTensor_%s', name)
end

local function cname(name)
return string.format('THCudaTensor_%s', name)
end

local function lastdim(argn)
return function(arg)
return string.format("THCudaTensor_nDimension(cutorch_getstate(L), %s)", arg.args[argn]:carg())
end
end

local function lastdimarray(argn)
return function(arg)
return string.format("THCudaTensor_nDimension(cutorch_getstate(L), arg%d_data[0])", arg.args[argn].i)
end
end

cutorch_state_code = function(varname)
local txt = {}
table.insert(txt, 'lua_getglobal(L, "cutorch");')
@@ -345,17 +332,179 @@ local function wrap(...)
method:wrap(unpack(args))
end
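Only the tail of the `wrap` helper survives in this hunk. Judging from the paired `method:register`/`interface:register` calls later in the file, it forwards one declaration to both generators; a hypothetical reconstruction under that assumption (the real argument plumbing may differ):

-- Hypothetical reconstruction, not the committed code: emit both the
-- free-function form (torch.foo(t, ...)) and the method form (t:foo(...)).
local function wrap(...)
   local args = {...}
   interface:wrap(...)
   method:wrap(unpack(args))
end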

--
-- Non-CudaTensor type math, since these are less fully implemented than
-- CudaTensor
--

local handledTypenames = {'CudaByteTensor',
'CudaCharTensor',
'CudaShortTensor',
'CudaIntTensor',
'CudaLongTensor',
'CudaDoubleTensor'}
local handledTypereals = {'unsigned char',
'char',
'short',
'int',
'long',
'double'}
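`handledTypenames` and `handledTypereals` are parallel arrays: index `k` pairs each tensor type with the C scalar type used for its element arguments, which the loop below reads back as `handledTypereals[k]`. A quick sketch of the pairing:

-- Sketch: the lists line up by index ('CudaByteTensor' with 'unsigned char',
-- 'CudaLongTensor' with 'long', and so on).
for k, Tensor in ipairs(handledTypenames) do
   print(Tensor, handledTypereals[k])
end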

for k, Tensor in pairs(handledTypenames) do
local real = handledTypereals[k]

function interface.luaname2wrapname(self, name)
return string.format('cutorch_%s_%s', Tensor, name)
end

function method.luaname2wrapname(self, name)
return string.format('m_cutorch_%s_%s', Tensor, name)
end

local function cname(name)
return string.format('TH%s_%s', Tensor, name)
end

local function lastdim(argn)
return function(arg)
return string.format('TH%s_nDimension(cutorch_getstate(L), %s)',
Tensor, arg.args[argn]:carg())
end
end

local function lastdimarray(argn)
return function(arg)
return string.format('TH%s_nDimension(cutorch_getstate(L), arg%d_data[0])',
Tensor, arg.args[argn].i)
end
end

wrap("fill",
cname("fill"),
{{name=Tensor, returned=true},
{name=real}})

wrap("zero",
cname("zero"),
{{name=Tensor, returned=true}})

wrap("zeros",
cname("zeros"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name="LongArg"}})

wrap("ones",
cname("ones"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name="LongArg"}})

wrap("reshape",
cname("reshape"),
{{name=Tensor, default=true, returned=true},
{name=Tensor},
{name="LongArg"}})

wrap("numel",
cname("numel"),
{{name=Tensor},
{name="long", creturned=true}})

wrap("add",
cname("add"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name=Tensor, method={default=1}},
{name=real}},
cname("cadd"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name=Tensor, method={default=1}},
{name=real, default=1},
{name=Tensor}})

wrap("csub",
cname("sub"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name=Tensor, method={default=1}},
{name=real}},
cname("csub"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name=Tensor, method={default=1}},
{name=real, default=1},
{name=Tensor}})

for _, name in ipairs({"cmul", "cpow", "cdiv"}) do
wrap(name,
cname(name),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name=Tensor, method={default=1}},
{name=Tensor}})
end
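For orientation, this is roughly how the declarations above surface in Lua once the method and interface tables are registered below and cutorch is built (illustrative usage; standard torch calling conventions assumed):

-- Illustrative usage, assuming standard torch semantics for the wrapped calls.
local a = torch.CudaIntTensor(4, 4):fill(3)   -- fill: new int tensor on the GPU
local b = torch.CudaIntTensor(4, 4):fill(2)
a:add(1)                     -- method form, in place: a = a + 1
local c = torch.add(a, b)    -- free-function form allocates the result
a:cmul(b)                    -- element-wise multiply, in place
print(a:numel())             --> 16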

method:register("m_cutorch_" .. Tensor .. "Math__")
interface:print(method:tostring())
method:clearhistory()
method:registerDefaultArgument(cutorch_state_code)
interface:register("cutorch_" .. Tensor .. "Math__")

interface:print(string.format([[
void cutorch_%sMath_init(lua_State *L)
{
luaT_pushmetatable(L, "torch.%s");
/* register methods */
luaL_setfuncs(L, m_cutorch_%sMath__, 0);
/* register functions into the "torch" field of the tensor metaclass */
lua_pushstring(L, "torch");
lua_newtable(L);
luaL_setfuncs(L, cutorch_%sMath__, 0);
lua_rawset(L, -3);
lua_pop(L, 1);
}
]], Tensor, Tensor, Tensor, Tensor))
end
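The `string.format` template above is stamped out once per handled type, so e.g. `Tensor = 'CudaIntTensor'` yields an init function named `cutorch_CudaIntTensorMath_init` that installs the method table on the `torch.CudaIntTensor` metatable and the functional variants under its `torch` field. The substitution itself, for one concrete type:

-- Sketch: the same substitution the template performs.
print(string.format('void cutorch_%sMath_init(lua_State *L)', 'CudaIntTensor'))
--> void cutorch_CudaIntTensorMath_init(lua_State *L)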


--
-- CudaTensor special handling, since it is more fully implemented
--

local Tensor = "CudaTensor"
local real = "float"

function interface.luaname2wrapname(self, name)
return string.format('cutorch_%s_%s', Tensor, name)
end

function method.luaname2wrapname(self, name)
return string.format('m_cutorch_%s_%s', Tensor, name)
end

local function cname(name)
return string.format('TH%s_%s', Tensor, name)
end

local function lastdim(argn)
return function(arg)
return string.format('TH%s_nDimension(cutorch_getstate(L), %s)',
Tensor, arg.args[argn]:carg())
end
end

local function lastdimarray(argn)
return function(arg)
return string.format('TH%s_nDimension(cutorch_getstate(L), arg%d_data[0])',
Tensor, arg.args[argn].i)
end
end

wrap("zero",
cname("zero"),
{{name=Tensor, returned=true}})

wrap("fill",
cname("fill"),
{{name=Tensor, returned=true},
{name=real}})

wrap("zeros",
cname("zeros"),
@@ -1012,26 +1161,26 @@ wrap("squeeze",
end},
{name="index"}})
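Several hundred CudaTensor declarations are elided from this hunk; the header shows `wrap("squeeze", ...)` among them as context. Its expected Lua-level behaviour, assuming the usual CPU torch semantics carry over:

-- Assumed usage: squeeze drops singleton dimensions, as it does for CPU tensors.
local t = torch.CudaTensor(2, 1, 3)
print(t:squeeze():size())    --> 2 3 (a torch.LongStorage of size 2)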

method:register("m_cutorch_CudaTensorMath__")
method:register("m_cutorch_" .. Tensor .. "Math__")
interface:print(method:tostring())
method:clearhistory()
interface:register("cutorch_CudaTensorMath__")
interface:register("cutorch_" .. Tensor .. "Math__")

interface:print([[
void cutorch_CudaTensorMath_init(lua_State *L)
interface:print(string.format([[
void cutorch_%sMath_init(lua_State *L)
{
luaT_pushmetatable(L, "torch.CudaTensor");
luaT_pushmetatable(L, "torch.%s");
/* register methods */
luaL_setfuncs(L, m_cutorch_CudaTensorMath__, 0);
luaL_setfuncs(L, m_cutorch_%sMath__, 0);
/* register functions into the "torch" field of the tensor metaclass */
lua_pushstring(L, "torch");
lua_newtable(L);
luaL_setfuncs(L, cutorch_CudaTensorMath__, 0);
luaL_setfuncs(L, cutorch_%sMath__, 0);
lua_rawset(L, -3);
lua_pop(L, 1);
}
]])
]], Tensor, Tensor, Tensor, Tensor))

interface:tofile(arg[1])
6 changes: 4 additions & 2 deletions generic/CStorage.c
@@ -2,6 +2,8 @@
#define THC_GENERIC_FILE "generic/CStorage.c"
#else

#include "THCHalf.h"

/* everything is as the generic Storage.c, except few things (see below) */

#ifndef THC_REAL_IS_HALF
@@ -66,7 +68,7 @@ static int cutorch_Storage_(copy)(lua_State *L)
THCStorage_(copyCudaFloat)(state, storage, src);
else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleStorage")) )
THCStorage_(copyCudaDouble)(state, storage, src);
#if CUDA_VERSION >= 7050
#ifdef CUDA_HALF_TENSOR
else if( (src = luaT_toudata(L, 2, "torch.CudaHalfStorage")) )
THCStorage_(copyCudaHalf)(state, storage, src);
#endif
@@ -127,7 +129,7 @@ static int TH_CONCAT_3(cutorch_,Real,Storage_copy)(lua_State *L)
THStorage_(copyCudaInt)(cutorch_getstate(L), storage, src);
else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleStorage")) )
THStorage_(copyCudaDouble)(cutorch_getstate(L), storage, src);
#if CUDA_VERSION >= 7050
#ifdef CUDA_HALF_TENSOR
else if( (src = luaT_toudata(L, 2, "torch.CudaHalfStorage")) )
THStorage_(copyCudaHalf)(cutorch_getstate(L), storage, src);
#endif
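Note the guard change in both hunks: a raw toolkit-version check (`CUDA_VERSION >= 7050`) becomes the feature macro `CUDA_HALF_TENSOR`, which reflects whether half-precision support was actually compiled in rather than which toolkit headers were present. At the Lua level, the branches above implement cross-type storage copies; an illustrative sketch (assumes a cutorch build from this commit; `torch.CudaHalfStorage` exists only when `CUDA_HALF_TENSOR` was defined):

-- Illustrative dispatch, assuming the standard cutorch storage API.
local src = torch.CudaStorage(4)          -- float elements on the GPU
local dst = torch.CudaDoubleStorage(4)
dst:copy(src)   -- takes the torch.CudaStorage branch -> THCStorage_copyCudaFloat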
6 changes: 4 additions & 2 deletions generic/CTensor.c
@@ -2,6 +2,8 @@
#define THC_GENERIC_FILE "generic/CTensor.c"
#else

#include "THCHalf.h"

/* everything is as the generic Storage.c, except few things (see below) */

#define TH_GENERIC_FILE "generic/Tensor.c"
@@ -28,7 +30,7 @@ static int cutorch_Tensor_(copy)(lua_State *L)
THCTensor_(copyCudaLong)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleTensor")) )
THCTensor_(copyCudaDouble)(state, tensor, src);
#if CUDA_VERSION >= 7050
#ifdef CUDA_HALF_TENSOR
else if( (src = luaT_toudata(L, 2, "torch.CudaHalfTensor")) )
THCTensor_(copyCudaHalf)(state, tensor, src);
#endif
@@ -110,7 +112,7 @@ static int TH_CONCAT_3(cutorch_,Real,Tensor_copy)(lua_State *L)
THTensor_(copyCudaFloat)(cutorch_getstate(L), tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleTensor")) )
THTensor_(copyCudaDouble)(cutorch_getstate(L), tensor, src);
#if CUDA_VERSION >= 7050
#ifdef CUDA_HALF_TENSOR
else if( (src = luaT_toudata(L, 2, "torch.CudaHalfTensor")) )
THTensor_(copyCudaHalf)(cutorch_getstate(L), tensor, src);
#endif