-- NOTE: web-page residue from the source listing ("Permalink", "219 lines (191 sloc) 6.4 KB"); not part of the program.
local ok, ffi = pcall(require, 'ffi')
--- Raise a "bad argument" error when a precondition does not hold.
-- @param condition value tested for truthiness; nothing happens when it is truthy
-- @param fn name of the function reported in the message
-- @param ud argument position reported in the message
-- @param msg explanation placed between parentheses in the message
-- @param level optional stack level forwarded to error(); defaults to 3
local function checkArgument(condition, fn, ud, msg, level)
   if condition then
      return
   end
   local where = level or 3
   local text = "bad argument #" .. ud .. " to '" .. fn .. "' (" .. msg .. ")"
   error(text, where)
end
--- Raise a "bad argument" error when two type names differ.
-- @param expected type name that is required
-- @param actual type name that was received
-- @param fn name of the function reported in the message
-- @param ud argument position reported in the message
-- @param level optional stack level; defaults to 3 and is bumped by one
--        because the error is raised from inside checkArgument
local function checkArgumentType(expected, actual, fn, ud, level)
   if expected == actual then
      return
   end
   local detail = expected .. " expected, got " .. actual
   checkArgument(false, fn, ud, detail, (level or 3) + 1)
end
if ok then
-- Maps the capitalised element-type name used in Torch type names
-- ("torch.<Real>Tensor", "TH<Real>Storage") to the C element type that is
-- substituted into the struct declarations below.
local Real2real = {
   Byte   = 'unsigned char',
   Char   = 'char',
   Short  = 'short',
   Int    = 'int',
   Long   = 'long',
   Float  = 'float',
   Double = 'double',
   Half   = 'THHalf',
}
-- Allocator
-- Declares the THAllocator struct to the FFI: three function pointers
-- (malloc, realloc, free), each taking an opaque context pointer as its
-- first argument. The storage struct declared further down holds a
-- pointer to this type, so it has to be declared first.
ffi.cdef[[
typedef struct THAllocator {
void* (*malloc)(void*, ptrdiff_t);
void* (*realloc)(void*, void*, ptrdiff_t);
void (*free)(void*, void*);
} THAllocator;
]]
-- Half
-- Declares THHalf to the FFI as a struct wrapping a single unsigned short.
-- THHalf is the C element type that Real2real associates with 'Half', so
-- it must be known before the Half storage/tensor structs are declared.
ffi.cdef[[
typedef struct {
unsigned short x;
} __THHalf;
typedef __THHalf THHalf;
]]
-- Storage
for Real, real in pairs(Real2real) do
   -- Generic storage layout; 'Real' and 'real' are placeholders that get
   -- replaced by the type name and the C element type respectively.
   local template = [[
typedef struct THRealStorage
{
real *data;
ptrdiff_t size;
int refcount;
char flag;
THAllocator *allocator;
void *allocatorContext;
} THRealStorage;
]]
   local cdefs = template:gsub('Real', Real):gsub('real', real)
   ffi.cdef(cdefs)

   local Storage = torch.getmetatable('torch.' .. Real .. 'Storage')
   local Storage_tt = ffi.typeof(string.format('TH%sStorage**', Real))

   -- Returns the TH<Real>Storage* obtained by converting the Lua object to
   -- a TH<Real>Storage** and dereferencing it.
   local function storage_cdata(self)
      return Storage_tt(self)[0]
   end

   -- Returns the `data` field (element pointer) of the underlying struct.
   local function storage_data(self)
      return Storage_tt(self)[0].data
   end

   -- rawset is used so the accessors land directly in the metatable.
   rawset(Storage, "cdata", storage_cdata)
   rawset(Storage, "data", storage_data)
end
-- Tensor
for Real, real in pairs(Real2real) do
-- Specialise the generic tensor layout for the current element type and
-- declare it to the FFI. 'Real' / 'real' are placeholders replaced by the
-- type name and the C element type.
local cdefs = ([[
typedef struct THRealTensor
{
long *size;
long *stride;
int nDimension;
THRealStorage *storage;
ptrdiff_t storageOffset;
int refcount;
char flag;
} THRealTensor;
]]):gsub('Real', Real):gsub('real', real)
ffi.cdef(cdefs)

-- Torch type name, its metatable, and the pointer-to-pointer ctype used
-- by the accessors to reach the underlying struct.
local Tensor_type = 'torch.' .. Real .. 'Tensor'
local Tensor = torch.getmetatable(Tensor_type)
local Tensor_tt = ffi.typeof(string.format('TH%sTensor**', Real))
-- tensor:cdata(): returns the TH<Real>Tensor* behind the Lua object, or nil
-- when called without a receiver.
rawset(Tensor, "cdata", function(self)
   if self then
      return Tensor_tt(self)[0]
   end
   return nil
end)
-- tensor:data(): returns a pointer to the first element addressed by the
-- tensor (storage data advanced by storageOffset). Returns nil when called
-- without a receiver or when the tensor has no storage.
rawset(Tensor, "data", function(self)
   if not self then
      return nil
   end
   local th = Tensor_tt(self)[0]
   if th.storage == nil then
      return nil
   end
   return th.storage.data + th.storageOffset
end)
-- faster apply (contiguous case)
if Tensor_type ~= 'torch.HalfTensor' then
-- Keep the previously registered implementation for the fallback path.
local generic_apply = Tensor.apply
-- tensor:apply(func): calls func on every element and stores the result
-- back when func returns a truthy value. Contiguous tensors are walked
-- directly through the element pointer; everything else is delegated to
-- the previous implementation. Returns the tensor on the fast path.
rawset(Tensor, "apply", function(self, func)
   -- NOTE(review): `self.data` presumably checks that a data accessor is
   -- available on this tensor type -- confirm.
   if not (self:isContiguous() and self.data) then
      return generic_apply(self, func)
   end
   local ptr = self:data()
   for k = 0, self:nElement() - 1 do
      -- tonumber() required for long...
      local out = func(tonumber(ptr[k]))
      if out then
         ptr[k] = out
      end
   end
   return self
end)
-- faster map (contiguous case)
-- Keep the previously registered implementation for the fallback path.
local generic_map = Tensor.map
-- tensor:map(src, func): calls func(self[i], src[i]) for every element and
-- stores the result back into self when it is truthy. src must be a tensor
-- of the same type as self. When both tensors are contiguous they are
-- walked through their element pointers (element counts must match);
-- otherwise the call is delegated to the previous implementation.
rawset(Tensor, "map", function(self, src, func)
   checkArgument(torch.isTensor(src), "map", 1, "tensor expected")
   checkArgumentType(self:type(), src:type(), "map", 1)

   local fast = self:isContiguous() and src:isContiguous() and self.data and src.data
   if not fast then
      return generic_map(self, src, func)
   end

   local dst_p = self:data()
   local src_p = src:data()
   assert(src:nElement() == self:nElement(), 'size mismatch')
   for k = 0, self:nElement() - 1 do
      -- tonumber() required for long...
      local out = func(tonumber(dst_p[k]), tonumber(src_p[k]))
      if out then
         dst_p[k] = out
      end
   end
   return self
end)
-- faster map2 (contiguous case)
-- Keep the previously registered implementation for the fallback path.
local map2 = Tensor.map2
-- tensor:map2(src1, src2, func): calls func(self[i], src1[i], src2[i]) for
-- every element and stores the result back into self when it is truthy.
-- src1 and src2 must be tensors of the same type as self. When all three
-- tensors are contiguous they are walked through their element pointers
-- (element counts must match); otherwise the call is delegated to the
-- previous implementation. Returns self on the fast path.
rawset(Tensor,
       "map2",
       function(self, src1, src2, func)
          -- Argument errors must name 'map2' (not 'map') so the message
          -- points at the function that was actually called.
          checkArgument(torch.isTensor(src1), "map2", 1, "tensor expected")
          checkArgument(torch.isTensor(src2), "map2", 2, "tensor expected")
          checkArgumentType(self:type(), src1:type(), "map2", 1)
          checkArgumentType(self:type(), src2:type(), "map2", 2)
          if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then
             local self_d = self:data()
             local src1_d = src1:data()
             local src2_d = src2:data()
             assert(src1:nElement() == self:nElement(), 'size mismatch')
             assert(src2:nElement() == self:nElement(), 'size mismatch')
             for i=0,self:nElement()-1 do
                local res = func(tonumber(self_d[i]), tonumber(src1_d[i]), tonumber(src2_d[i])) -- tonumber() required for long...
                if res then
                   self_d[i] = res
                end
             end
             return self
          else
             return map2(self, src1, src2, func)
          end
       end)
end
end
-- torch.data
-- Returns obj:data(); when `asnumber` is truthy the result is cast to
-- intptr_t instead. Returns nil when no object is given.
-- will fail if :data() is not defined
-- NOTE(review): if :data() returns nil, the intptr_t cast presumably
-- raises a conversion error -- confirm against LuaJIT semantics.
function torch.data(self, asnumber)
   if not self then
      return nil
   end
   local ptr = self:data()
   if not asnumber then
      return ptr
   end
   return ffi.cast('intptr_t', ptr)
end
-- torch.cdata
-- Returns obj:cdata(); when `asnumber` is truthy the result is cast to
-- intptr_t instead. Returns nil when no object is given.
-- will fail if :cdata() is not defined
function torch.cdata(self, asnumber)
   if not self then
      return nil
   end
   local struct_ptr = self:cdata()
   if not asnumber then
      return struct_ptr
   end
   return ffi.cast('intptr_t', struct_ptr)
end
end