
Add support for torch.HalfTensor (#874)

* Add support for torch.HalfTensor.

* Improvements/Simplifications for torch.HalfTensor.

Improvements/Simplifications:
1) Defines the half type as TH_Half, so as not to conflict with the cutorch
version.  Previously, these were defined as the same "half" type and
required proper ordering of includes to ensure the type was only defined
once, which would have affected all downstream projects.
2) No longer generates math functions that are not actually defined
on torch.HalfTensor, e.g. maskedFill, map, etc.
3) Adds tests for all available torch.HalfTensor functions
4) Allows compiling without TH_GENERIC_USE_HALF (so if there's a
problem, it can simply be unset in CMakeLists rather than backing out the change)
5) Some simplifications: removes a new copy optimization and
some TH_HALF literal definitions

Limitations:
Because math functions are not defined, some "non-math" operators
on torch.HalfTensor give an error message; e.g. __index__/__newindex__
with a ByteTensor apply a mask, but masks aren't implemented.  These
limitations aren't always obvious (e.g. for documentation purposes),
but they should always give an error message; see the sketch below.

* Rename TH_HALF to THHalf.
1 parent 7ca7ec9 · commit a0c0b78471df5f4507791e870cf7df9607a64400 · @gchanan committed with soumith on Dec 29, 2016
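A hedged illustration of the Limitations note above (a sketch, not code from this commit; the exact error text will vary):

    -- conversions to and from half work
    local h = torch.FloatTensor({1, 2, 3}):half()   -- torch.HalfTensor
    local f = h:float()                             -- back to float

    -- masked __newindex__ is one of the unimplemented "non-math"
    -- operators, so it should fail loudly rather than misbehave
    local mask = torch.ByteTensor({1, 0, 1})
    local ok, err = pcall(function() h[mask] = 0 end)
    assert(not ok)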
@@ -25,6 +25,8 @@ IF(MSVC)
ADD_DEFINITIONS(-D_CRT_SECURE_NO_DEPRECATE=1)
ENDIF(MSVC)
+ADD_DEFINITIONS(-DTH_GENERIC_USE_HALF=1)
+
# OpenMP support?
SET(WITH_OPENMP ON CACHE BOOL "OpenMP support if available?")
IF (APPLE AND CMAKE_COMPILER_IS_GNUCC)
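This hunk turns half support on by default; per point 4 of the commit message, the definition can be unset here if problems arise. A minimal sketch (assuming torch.hashalfmath(), introduced alongside this change, reports whether native half math was compiled in) of guarding half-specific Lua code the way the FFI changes below do:

    local t
    if torch.hashalfmath() then
       t = torch.HalfTensor(4)   -- native half math available
    else
       t = torch.FloatTensor(4)  -- fall back to float arithmetic
    end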
@@ -15,14 +15,16 @@ local function checkArgumentType(expected, actual, fn, ud, level)
end
if ok then
+
local Real2real = {
Byte='unsigned char',
Char='char',
Short='short',
Int='int',
Long='long',
Float='float',
- Double='double'
+ Double='double',
+ Half='THHalf'
}
-- Allocator
@@ -34,6 +36,14 @@ typedef struct THAllocator {
} THAllocator;
]]
+ -- Half
+ ffi.cdef[[
+typedef struct {
+ unsigned short x;
+} __THHalf;
+typedef __THHalf THHalf;
+]]
+
-- Storage
for Real, real in pairs(Real2real) do
@@ -76,7 +86,7 @@ typedef struct THRealTensor
long *size;
long *stride;
int nDimension;
-
+
THRealStorage *storage;
ptrdiff_t storageOffset;
int refcount;
@@ -88,7 +98,8 @@ typedef struct THRealTensor
cdefs = cdefs:gsub('Real', Real):gsub('real', real)
ffi.cdef(cdefs)
- local Tensor = torch.getmetatable(string.format('torch.%sTensor', Real))
+ local Tensor_type = string.format('torch.%sTensor', Real)
+ local Tensor = torch.getmetatable(Tensor_type)
local Tensor_tt = ffi.typeof('TH' .. Real .. 'Tensor**')
rawset(Tensor,
@@ -107,75 +118,77 @@ typedef struct THRealTensor
end)
-- faster apply (contiguous case)
- local apply = Tensor.apply
- rawset(Tensor,
- "apply",
- function(self, func)
- if self:isContiguous() and self.data then
- local self_d = self:data()
- for i=0,self:nElement()-1 do
- local res = func(tonumber(self_d[i])) -- tonumber() required for long...
- if res then
- self_d[i] = res
+ if Tensor_type ~= 'torch.HalfTensor' or torch.hashalfmath() then
+ local apply = Tensor.apply
+ rawset(Tensor,
+ "apply",
+ function(self, func)
+ if self:isContiguous() and self.data then
+ local self_d = self:data()
+ for i=0,self:nElement()-1 do
+ local res = func(tonumber(self_d[i])) -- tonumber() required for long...
+ if res then
+ self_d[i] = res
+ end
end
+ return self
+ else
+ return apply(self, func)
end
- return self
- else
- return apply(self, func)
- end
- end)
-
- -- faster map (contiguous case)
- local map = Tensor.map
- rawset(Tensor,
- "map",
- function(self, src, func)
- checkArgument(torch.isTensor(src), "map", 1, "tensor expected")
- checkArgumentType(self:type(), src:type(), "map", 1)
-
- if self:isContiguous() and src:isContiguous() and self.data and src.data then
- local self_d = self:data()
- local src_d = src:data()
- assert(src:nElement() == self:nElement(), 'size mismatch')
- for i=0,self:nElement()-1 do
- local res = func(tonumber(self_d[i]), tonumber(src_d[i])) -- tonumber() required for long...
- if res then
- self_d[i] = res
+ end)
+
+ -- faster map (contiguous case)
+ local map = Tensor.map
+ rawset(Tensor,
+ "map",
+ function(self, src, func)
+ checkArgument(torch.isTensor(src), "map", 1, "tensor expected")
+ checkArgumentType(self:type(), src:type(), "map", 1)
+
+ if self:isContiguous() and src:isContiguous() and self.data and src.data then
+ local self_d = self:data()
+ local src_d = src:data()
+ assert(src:nElement() == self:nElement(), 'size mismatch')
+ for i=0,self:nElement()-1 do
+ local res = func(tonumber(self_d[i]), tonumber(src_d[i])) -- tonumber() required for long...
+ if res then
+ self_d[i] = res
+ end
end
+ return self
+ else
+ return map(self, src, func)
end
- return self
- else
- return map(self, src, func)
- end
- end)
-
- -- faster map2 (contiguous case)
- local map2 = Tensor.map2
- rawset(Tensor,
- "map2",
- function(self, src1, src2, func)
- checkArgument(torch.isTensor(src1), "map", 1, "tensor expected")
- checkArgument(torch.isTensor(src2), "map", 2, "tensor expected")
- checkArgumentType(self:type(), src1:type(), "map", 1)
- checkArgumentType(self:type(), src2:type(), "map", 2)
-
- if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then
- local self_d = self:data()
- local src1_d = src1:data()
- local src2_d = src2:data()
- assert(src1:nElement() == self:nElement(), 'size mismatch')
- assert(src2:nElement() == self:nElement(), 'size mismatch')
- for i=0,self:nElement()-1 do
- local res = func(tonumber(self_d[i]), tonumber(src1_d[i]), tonumber(src2_d[i])) -- tonumber() required for long...
- if res then
- self_d[i] = res
+ end)
+
+ -- faster map2 (contiguous case)
+ local map2 = Tensor.map2
+ rawset(Tensor,
+ "map2",
+ function(self, src1, src2, func)
+ checkArgument(torch.isTensor(src1), "map", 1, "tensor expected")
+ checkArgument(torch.isTensor(src2), "map", 2, "tensor expected")
+ checkArgumentType(self:type(), src1:type(), "map", 1)
+ checkArgumentType(self:type(), src2:type(), "map", 2)
+
+ if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then
+ local self_d = self:data()
+ local src1_d = src1:data()
+ local src2_d = src2:data()
+ assert(src1:nElement() == self:nElement(), 'size mismatch')
+ assert(src2:nElement() == self:nElement(), 'size mismatch')
+ for i=0,self:nElement()-1 do
+ local res = func(tonumber(self_d[i]), tonumber(src1_d[i]), tonumber(src2_d[i])) -- tonumber() required for long...
+ if res then
+ self_d[i] = res
+ end
end
+ return self
+ else
+ return map2(self, src1, src2, func)
end
- return self
- else
- return map2(self, src1, src2, func)
- end
- end)
+ end)
+ end
end
-- torch.data
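A short sketch of what the fast path above buys (standard torch apply/map usage, not code from this commit): when a tensor is contiguous and exposes FFI data, the Lua callback runs directly over the raw buffer.

    local t = torch.DoubleTensor(5):fill(2)
    t:apply(function(x) return x * x end)      -- ffi loop when contiguous

    local a = torch.DoubleTensor(3):fill(1)
    local b = torch.DoubleTensor(3):fill(2)
    a:map(b, function(x, y) return x + y end)  -- a is now all 3s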
@@ -5,14 +5,14 @@ local Storage = {}
local Tensor = {}
-- types
-local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'}
+local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Half', 'Double'}
-- Lua 5.2 compatibility
local log10 = math.log10 or function(x) return math.log(x, 10) end
-- tostring() functions for Tensor and Storage
local function Storage__printformat(self)
- if self:size() == 0 then
+ if self:size() == 0 then
return "", nil, 0
end
local intMode = true
@@ -277,6 +277,10 @@ function Tensor.double(self)
return self:type('torch.DoubleTensor')
end
+function Tensor.half(self)
+ return self:type('torch.HalfTensor')
+end
+
function Tensor.real(self)
return self:type(torch.getdefaulttensortype())
end
@@ -556,6 +560,14 @@ torch.permute = Tensor.permute
for _,type in ipairs(types) do
local metatable = torch.getmetatable('torch.' .. type .. 'Tensor')
for funcname, func in pairs(Tensor) do
- rawset(metatable, funcname, func)
+      if funcname ~= 'totable' or type ~= 'Half' or torch.hashalfmath() then
+ rawset(metatable, funcname, func)
+ else
+ local function Tensor__totable(self)
+ local host_tensor = self:float()
+          return host_tensor:totable()
+ end
+ rawset(torch.getmetatable('torch.HalfTensor'), 'totable', Tensor__totable)
+ end
end
end
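A short sketch exercising both additions in this hunk: the new :half() conversion and the totable fallback that round-trips through float when native half math is unavailable (0.5 and 1.5 are exactly representable in half, so no rounding occurs here):

    local d = torch.DoubleTensor({0.5, 1.5})
    local h = d:half()        -- new Tensor.half conversion
    local t = h:totable()     -- may fall back to self:float():totable()
    print(t[1], t[2])         -- 0.5  1.5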
@@ -6,56 +6,7 @@ local interface = wrap.CInterface.new()
local method = wrap.CInterface.new()
local argtypes = wrap.CInterface.argtypes
-argtypes['ptrdiff_t'] = {
-
- helpname = function(arg)
- return 'ptrdiff_t'
- end,
-
- declare = function(arg)
- -- if it is a number we initialize here
- local default = tonumber(tostring(arg.default)) or 0
- return string.format("%s arg%d = %g;", 'ptrdiff_t', arg.i, default)
- end,
-
- check = function(arg, idx)
- return string.format("lua_isnumber(L, %d)", idx)
- end,
-
- read = function(arg, idx)
- return string.format("arg%d = (%s)lua_tonumber(L, %d);", arg.i, 'ptrdiff_t', idx)
- end,
-
- init = function(arg)
- -- otherwise do it here
- if arg.default then
- local default = tostring(arg.default)
- if not tonumber(default) then
- return string.format("arg%d = %s;", arg.i, default)
- end
- end
- end,
-
- carg = function(arg)
- return string.format('arg%d', arg.i)
- end,
-
- creturn = function(arg)
- return string.format('arg%d', arg.i)
- end,
-
- precall = function(arg)
- if arg.returned then
- return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
- end
- end,
-
- postcall = function(arg)
- if arg.creturned then
- return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
- end
- end
-}
+argtypes['ptrdiff_t'] = wrap.types.ptrdiff_t
interface:print([[
#include "TH.h"
@@ -216,6 +167,7 @@ local reals = {ByteTensor='unsigned char',
IntTensor='int',
LongTensor='long',
FloatTensor='float',
+ HalfTensor='half',
DoubleTensor='double'}
local accreals = {ByteTensor='long',
@@ -224,11 +176,12 @@ local accreals = {ByteTensor='long',
IntTensor='long',
LongTensor='long',
FloatTensor='double',
+ HalfTensor='float',
DoubleTensor='double'}
for _,Tensor in ipairs({"ByteTensor", "CharTensor",
"ShortTensor", "IntTensor", "LongTensor",
- "FloatTensor", "DoubleTensor"}) do
+ "FloatTensor", "HalfTensor", "DoubleTensor"}) do
local real = reals[Tensor]
local accreal = accreals[Tensor]
@@ -257,6 +210,7 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
end
end
+ if Tensor ~= 'HalfTensor' then
wrap("zero",
cname("zero"),
{{name=Tensor, returned=true}})
@@ -1030,6 +984,7 @@ static void THTensor_random1__(THTensor *self, THGenerator *gen, long b)
cname("nonzero"),
{{name="IndexTensor", default=true, returned=true},
{name=Tensor}})
+ end -- ~= HalfTensor
if Tensor == 'ByteTensor' then
-- Logical accumulators only apply to ByteTensor
@@ -1483,6 +1438,9 @@ void torch_TensorMath_init(lua_State *L)
torch_IntTensorMath_init(L);
torch_LongTensorMath_init(L);
torch_FloatTensorMath_init(L);
+ #if TH_NATIVE_HALF
+ torch_HalfTensorMath_init(L);
+ #endif
torch_DoubleTensorMath_init(L);
luaT_setfuncs(L, torch_TensorMath__, 0);
}
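Because the wrap() calls above are skipped for HalfTensor (and torch_HalfTensorMath_init only runs under TH_NATIVE_HALF), most math methods simply don't exist on torch.HalfTensor in a default build. A hedged sketch of what callers should expect:

    local h = torch.HalfTensor(3)
    -- zero() sits inside the `if Tensor ~= 'HalfTensor'` guard above,
    -- so the method isn't generated for half:
    local ok = pcall(function() return h:zero() end)
    print(ok)   -- expected: false when half math isn't available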
@@ -687,6 +687,7 @@ local typesMatching = {
['torch.LongStorage'] = torch.LongTensor,
['torch.FloatStorage'] = torch.FloatTensor,
['torch.DoubleStorage'] = torch.DoubleTensor,
+ ['torch.HalfStorage'] = torch.HalfTensor,
}
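The new typesMatching entry lets the tester pair the half storage type with its tensor class. An illustrative pairing (standard torch constructors, assumed available once half support is compiled in):

    local s = torch.HalfStorage(4)
    local t = torch.HalfTensor(s)   -- tensor viewing that storage
    assert(torch.type(t) == 'torch.HalfTensor')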
--[[ Tests for storage equality.
@@ -41,7 +41,7 @@ static int torch_Storage_(new)(lua_State *L)
THStorage_(free)(storage);
luaL_error(L, "element at index %d is not a number", i);
}
- THStorage_(set)(storage, i-1, (real)lua_tonumber(L, -1));
+ THStorage_(set)(storage, i-1, LUA_NUMBER_TO_REAL(lua_tonumber(L, -1)));
lua_pop(L, 1);
}
}
@@ -131,6 +131,10 @@ static int torch_Storage_(copy)(lua_State *L)
THStorage_(copyFloat)(storage, src);
else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) )
THStorage_(copyDouble)(storage, src);
+#if TH_GENERIC_USE_HALF
+ else if( (src = luaT_toudata(L, 2, "torch.HalfStorage")) )
+ THStorage_(copyHalf)(storage, src);
+#endif
else
luaL_typerror(L, 2, "torch.*Storage");
lua_settop(L, 1);
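Taken together, the two hunks above mean a half storage can be built from Lua numbers (the new LUA_NUMBER_TO_REAL path) and copied across storage types (the new copyHalf branch). A sketch, assuming TH_GENERIC_USE_HALF is set:

    local hs = torch.HalfStorage({1.0, 2.0, 3.0})  -- via LUA_NUMBER_TO_REAL
    local fs = torch.FloatStorage(3)
    fs:copy(hs)                                     -- dispatches to copyHalf
    print(fs[1])                                    -- 1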