
Merge branch 'master' into contiguous-cat-1d

2 parents d41580e + be986fb · commit 306984ff623886b8c07c8934623bd94fb52ed983 · @soumith committed on GitHub, Jan 1, 2017
@@ -15,14 +15,16 @@ local function checkArgumentType(expected, actual, fn, ud, level)
end
if ok then
+
local Real2real = {
Byte='unsigned char',
Char='char',
Short='short',
Int='int',
Long='long',
Float='float',
- Double='double'
+ Double='double',
+ Half='THHalf'
}
-- Allocator
@@ -34,6 +36,14 @@ typedef struct THAllocator {
} THAllocator;
]]
+ -- Half
+ ffi.cdef[[
+typedef struct {
+ unsigned short x;
+} __THHalf;
+typedef __THHalf THHalf;
+]]
+
-- Storage
for Real, real in pairs(Real2real) do
@@ -76,7 +86,7 @@ typedef struct THRealTensor
long *size;
long *stride;
int nDimension;
-
+
THRealStorage *storage;
ptrdiff_t storageOffset;
int refcount;
@@ -88,7 +98,8 @@ typedef struct THRealTensor
cdefs = cdefs:gsub('Real', Real):gsub('real', real)
ffi.cdef(cdefs)
- local Tensor = torch.getmetatable(string.format('torch.%sTensor', Real))
+ local Tensor_type = string.format('torch.%sTensor', Real)
+ local Tensor = torch.getmetatable(Tensor_type)
local Tensor_tt = ffi.typeof('TH' .. Real .. 'Tensor**')
rawset(Tensor,
@@ -107,75 +118,77 @@ typedef struct THRealTensor
end)
-- faster apply (contiguous case)
- local apply = Tensor.apply
- rawset(Tensor,
- "apply",
- function(self, func)
- if self:isContiguous() and self.data then
- local self_d = self:data()
- for i=0,self:nElement()-1 do
- local res = func(tonumber(self_d[i])) -- tonumber() required for long...
- if res then
- self_d[i] = res
+ if Tensor_type ~= 'torch.HalfTensor' then
+ local apply = Tensor.apply
+ rawset(Tensor,
+ "apply",
+ function(self, func)
+ if self:isContiguous() and self.data then
+ local self_d = self:data()
+ for i=0,self:nElement()-1 do
+ local res = func(tonumber(self_d[i])) -- tonumber() required for long...
+ if res then
+ self_d[i] = res
+ end
end
+ return self
+ else
+ return apply(self, func)
end
- return self
- else
- return apply(self, func)
- end
- end)
-
- -- faster map (contiguous case)
- local map = Tensor.map
- rawset(Tensor,
- "map",
- function(self, src, func)
- checkArgument(torch.isTensor(src), "map", 1, "tensor expected")
- checkArgumentType(self:type(), src:type(), "map", 1)
-
- if self:isContiguous() and src:isContiguous() and self.data and src.data then
- local self_d = self:data()
- local src_d = src:data()
- assert(src:nElement() == self:nElement(), 'size mismatch')
- for i=0,self:nElement()-1 do
- local res = func(tonumber(self_d[i]), tonumber(src_d[i])) -- tonumber() required for long...
- if res then
- self_d[i] = res
+ end)
+
+ -- faster map (contiguous case)
+ local map = Tensor.map
+ rawset(Tensor,
+ "map",
+ function(self, src, func)
+ checkArgument(torch.isTensor(src), "map", 1, "tensor expected")
+ checkArgumentType(self:type(), src:type(), "map", 1)
+
+ if self:isContiguous() and src:isContiguous() and self.data and src.data then
+ local self_d = self:data()
+ local src_d = src:data()
+ assert(src:nElement() == self:nElement(), 'size mismatch')
+ for i=0,self:nElement()-1 do
+ local res = func(tonumber(self_d[i]), tonumber(src_d[i])) -- tonumber() required for long...
+ if res then
+ self_d[i] = res
+ end
end
+ return self
+ else
+ return map(self, src, func)
end
- return self
- else
- return map(self, src, func)
- end
- end)
-
- -- faster map2 (contiguous case)
- local map2 = Tensor.map2
- rawset(Tensor,
- "map2",
- function(self, src1, src2, func)
- checkArgument(torch.isTensor(src1), "map", 1, "tensor expected")
- checkArgument(torch.isTensor(src2), "map", 2, "tensor expected")
- checkArgumentType(self:type(), src1:type(), "map", 1)
- checkArgumentType(self:type(), src2:type(), "map", 2)
-
- if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then
- local self_d = self:data()
- local src1_d = src1:data()
- local src2_d = src2:data()
- assert(src1:nElement() == self:nElement(), 'size mismatch')
- assert(src2:nElement() == self:nElement(), 'size mismatch')
- for i=0,self:nElement()-1 do
- local res = func(tonumber(self_d[i]), tonumber(src1_d[i]), tonumber(src2_d[i])) -- tonumber() required for long...
- if res then
- self_d[i] = res
+ end)
+
+ -- faster map2 (contiguous case)
+ local map2 = Tensor.map2
+ rawset(Tensor,
+ "map2",
+ function(self, src1, src2, func)
+         checkArgument(torch.isTensor(src1), "map2", 1, "tensor expected")
+         checkArgument(torch.isTensor(src2), "map2", 2, "tensor expected")
+         checkArgumentType(self:type(), src1:type(), "map2", 1)
+         checkArgumentType(self:type(), src2:type(), "map2", 2)
+
+ if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then
+ local self_d = self:data()
+ local src1_d = src1:data()
+ local src2_d = src2:data()
+ assert(src1:nElement() == self:nElement(), 'size mismatch')
+ assert(src2:nElement() == self:nElement(), 'size mismatch')
+ for i=0,self:nElement()-1 do
+ local res = func(tonumber(self_d[i]), tonumber(src1_d[i]), tonumber(src2_d[i])) -- tonumber() required for long...
+ if res then
+ self_d[i] = res
+ end
end
+ return self
+ else
+ return map2(self, src1, src2, func)
end
- return self
- else
- return map2(self, src1, src2, func)
- end
- end)
+ end)
+ end
end
-- torch.data
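
For orientation, a hedged illustration of the gating above (not part of the commit): the FFI fast paths are only installed when tensor elements are plain C numbers that tonumber() can handle. THHalf is a struct wrapping an unsigned short, so HalfTensor keeps the original apply/map/map2 implementations.

   -- hedged sketch; assumes a standard torch build
   local f = torch.FloatTensor(4):fill(1)
   f:apply(function(x) return x * 2 end)  -- takes the contiguous FFI path
   -- a torch.HalfTensor skips the rawset above entirely and falls back
   -- to the original, non-FFI implementations
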
@@ -7,3 +7,6 @@
#include "generic/Storage.c"
#include "THGenerateAllTypes.h"
+
+#include "generic/Storage.c"
+#include "THGenerateHalfType.h"
@@ -7,3 +7,6 @@
#include "generic/Tensor.c"
#include "THGenerateAllTypes.h"
+
+#include "generic/Tensor.c"
+#include "THGenerateHalfType.h"
@@ -5,14 +5,14 @@ local Storage = {}
local Tensor = {}
-- types
-local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'}
+local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Half', 'Double'}
-- Lua 5.2 compatibility
local log10 = math.log10 or function(x) return math.log(x, 10) end
-- tostring() functions for Tensor and Storage
local function Storage__printformat(self)
- if self:size() == 0 then
+ if self:size() == 0 then
return "", nil, 0
end
local intMode = true
@@ -277,6 +277,10 @@ function Tensor.double(self)
return self:type('torch.DoubleTensor')
end
+function Tensor.half(self)
+ return self:type('torch.HalfTensor')
+end
+
function Tensor.real(self)
return self:type(torch.getdefaulttensortype())
end
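
A hedged usage sketch of the new conversion method (assumes a build where Half/Float copies are wired up; not code from the commit):

   local d = torch.DoubleTensor(3):fill(0.5)
   local h = d:half()     -- new in this commit: returns a torch.HalfTensor
   print(torch.type(h))   -- 'torch.HalfTensor'
   local f = h:float()    -- round-trip back to a full-precision type
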
@@ -556,6 +560,14 @@ torch.permute = Tensor.permute
for _,type in ipairs(types) do
local metatable = torch.getmetatable('torch.' .. type .. 'Tensor')
for funcname, func in pairs(Tensor) do
- rawset(metatable, funcname, func)
+      if funcname ~= 'totable' or type ~= 'Half' then
+ rawset(metatable, funcname, func)
+ else
+ local function Tensor__totable(self)
+           -- Half has no native totable; convert to Float on the host first
+           local host_tensor = self:float()
+           return host_tensor:totable()
+ end
+ rawset(torch.getmetatable('torch.HalfTensor'), 'totable', Tensor__totable)
+ end
end
end
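
The totable special case exists because Half elements are not plain Lua numbers; a hedged sketch of the resulting behavior (assumes a Half-enabled build):

   local h = torch.HalfTensor(2, 2)  -- contents uninitialized
   local t = h:totable()             -- converts via :float() internally
   print(type(t[1][1]))              -- 'number': nested table of Lua numbers
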
@@ -6,56 +6,7 @@ local interface = wrap.CInterface.new()
local method = wrap.CInterface.new()
local argtypes = wrap.CInterface.argtypes
-argtypes['ptrdiff_t'] = {
-
- helpname = function(arg)
- return 'ptrdiff_t'
- end,
-
- declare = function(arg)
- -- if it is a number we initialize here
- local default = tonumber(tostring(arg.default)) or 0
- return string.format("%s arg%d = %g;", 'ptrdiff_t', arg.i, default)
- end,
-
- check = function(arg, idx)
- return string.format("lua_isnumber(L, %d)", idx)
- end,
-
- read = function(arg, idx)
- return string.format("arg%d = (%s)lua_tonumber(L, %d);", arg.i, 'ptrdiff_t', idx)
- end,
-
- init = function(arg)
- -- otherwise do it here
- if arg.default then
- local default = tostring(arg.default)
- if not tonumber(default) then
- return string.format("arg%d = %s;", arg.i, default)
- end
- end
- end,
-
- carg = function(arg)
- return string.format('arg%d', arg.i)
- end,
-
- creturn = function(arg)
- return string.format('arg%d', arg.i)
- end,
-
- precall = function(arg)
- if arg.returned then
- return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
- end
- end,
-
- postcall = function(arg)
- if arg.creturned then
- return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
- end
- end
-}
+argtypes['ptrdiff_t'] = wrap.types.ptrdiff_t
interface:print([[
#include "TH.h"
@@ -216,6 +167,7 @@ local reals = {ByteTensor='unsigned char',
IntTensor='int',
LongTensor='long',
FloatTensor='float',
+ HalfTensor='half',
DoubleTensor='double'}
local accreals = {ByteTensor='long',
@@ -224,11 +176,12 @@ local accreals = {ByteTensor='long',
IntTensor='long',
LongTensor='long',
FloatTensor='double',
+ HalfTensor='float',
DoubleTensor='double'}
for _,Tensor in ipairs({"ByteTensor", "CharTensor",
"ShortTensor", "IntTensor", "LongTensor",
- "FloatTensor", "DoubleTensor"}) do
+ "FloatTensor", "HalfTensor", "DoubleTensor"}) do
local real = reals[Tensor]
local accreal = accreals[Tensor]
@@ -257,6 +210,7 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
end
end
+ if Tensor ~= 'HalfTensor' then
wrap("zero",
cname("zero"),
{{name=Tensor, returned=true}})
@@ -1030,6 +984,7 @@ static void THTensor_random1__(THTensor *self, THGenerator *gen, long b)
cname("nonzero"),
{{name="IndexTensor", default=true, returned=true},
{name=Tensor}})
+ end -- ~= HalfTensor
if Tensor == 'ByteTensor' then
-- Logical accumulators only apply to ByteTensor
@@ -1089,6 +1044,14 @@ static void THTensor_random1__(THTensor *self, THGenerator *gen, long b)
{name="double",default=0},
{name="double",default=0}})
+ wrap("bhistc",
+ cname("bhistc"),
+ {{name=Tensor, default=true, returned=true},
+ {name=Tensor},
+ {name="long",default=100},
+ {name="double",default=0},
+ {name="double",default=0}})
+
wrap("norm",
cname("normall"),
{{name=Tensor},
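
For context, a hedged sketch of calling the newly wrapped bhistc; the argument roles are read off the declaration above (result tensor, input, nbins=100, min=0, max=0), and it presumably mirrors histc row-wise over a 2D input:

   local x = torch.FloatTensor{{1, 2, 1}, {3, 4, 4}}
   local h = x:bhistc(4, 1, 4)  -- 4 bins over [1, 4], one histogram per row
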
@@ -687,6 +687,7 @@ local typesMatching = {
['torch.LongStorage'] = torch.LongTensor,
['torch.FloatStorage'] = torch.FloatTensor,
['torch.DoubleStorage'] = torch.DoubleTensor,
+ ['torch.HalfStorage'] = torch.HalfTensor,
}
--[[ Tests for storage equality.
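
Finally, a hedged illustration of what the new typesMatching entry provides to these tests (hypothetical snippet, not part of the diff):

   local s = torch.HalfStorage(3)
   local ctor = typesMatching[torch.type(s)]  -- maps to torch.HalfTensor
   assert(ctor == torch.HalfTensor)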