Permalink
574 lines (530 sloc) 16.3 KB
-- Extra methods copied into every torch.<T>Storage metatable.
local Storage = {}
-- Extra methods copied into every torch.<T>Tensor metatable.
local Tensor = {}
-- Element-type prefixes whose Storage/Tensor metatables get patched.
local types = {
   'Byte', 'Char', 'Short', 'Int',
   'Long', 'Float', 'Half', 'Double',
}
-- math.log10 is deprecated since Lua 5.2 and may be missing; fall back to
-- the two-argument form of math.log.
local log10 = math.log10
if not log10 then
   log10 = function(x)
      return math.log(x, 10)
   end
end
-- tostring() functions for Tensor and Storage
--- Choose a string.format pattern for printing the elements of a storage.
-- Every element is scanned to decide between integer and floating-point
-- output. The decimal exponents of the smallest and largest absolute values
-- then pick a field width and, for very large/small magnitudes, a common
-- power-of-ten scale.
-- @param self a torch storage of any element type (copied to double here)
-- @return format pattern ("" for an empty storage; '%f' if the width came
--   out as NaN)
-- @return scale to divide every element by before printing, or nil
-- @return field width baked into the pattern (0 for an empty storage)
local function Storage__printformat(self)
   local n = self:size()
   if n == 0 then
      return "", nil, 0
   end

   -- Integer mode iff every element equals its ceiling (NaN never does).
   local intMode = true
   for i = 1, n do
      if self[i] ~= math.ceil(self[i]) then
         intMode = false
         break
      end
   end

   -- Exponent range of |x|, computed in double precision.
   local tensor = torch.DoubleTensor(torch.DoubleStorage(n):copy(self), 1, n):abs()
   local expMin = tensor:min()
   if expMin ~= 0 then
      expMin = math.floor(log10(expMin)) + 1
   else
      expMin = 1
   end
   local expMax = tensor:max()
   if expMax ~= 0 then
      expMax = math.floor(log10(expMax)) + 1
   else
      expMax = 1
   end

   local format, scale, sz
   if intMode then
      if expMax > 9 then
         format = "%11.4e"
         sz = 11
      else
         format = "%SZd"
         sz = expMax + 1
      end
   elseif expMax - expMin > 4 then
      -- Wide dynamic range: scientific notation.
      format = "%SZ.4e"
      sz = 11
      if math.abs(expMax) > 99 or math.abs(expMin) > 99 then
         sz = sz + 1 -- room for a three-digit exponent
      end
   elseif expMax > 5 or expMax < 0 then
      -- Narrow range at a large/small magnitude: factor out a power of ten.
      format = "%SZ.4f"
      sz = 7
      -- `^` rather than math.pow, which is deprecated/absent in Lua 5.3+.
      scale = 10 ^ (expMax - 1)
   else
      format = "%SZ.4f"
      if expMax == 0 then
         sz = 7
      else
         sz = expMax + 6
      end
   end

   if sz ~= sz then
      -- NaN width (NaN data): tostring(NaN) is platform-dependent ("nan",
      -- "-nan", ...), so never splice it into the pattern.
      format = '%f'
   else
      format = string.gsub(format, 'SZ', sz)
   end
   if scale == 1 then
      scale = nil
   end
   return format, scale, sz
end
--- Render a storage as one formatted element per line, followed by a
-- "[torch.XStorage of size N]" footer.
-- When a scale is proposed, it is printed once as a "<scale> *" header and
-- each element is divided by it.
function Storage.__tostring__(self)
   local strt = {}
   local format, scale = Storage__printformat(self)
   -- A NaN width may leak into the pattern as tostring(NaN), spelled "nan",
   -- "-nan", "-nan(ind)", ... depending on the platform: search anywhere
   -- instead of at a fixed offset.
   if format:lower():find('nan', 1, true) then
      format = '%f'
   end
   if scale then
      table.insert(strt, string.format('%g', scale) .. ' *\n')
      for i = 1, self:size() do
         table.insert(strt, string.format(format, self[i] / scale) .. '\n')
      end
   else
      for i = 1, self:size() do
         table.insert(strt, string.format(format, self[i]) .. '\n')
      end
   end
   table.insert(strt, '[' .. torch.typename(self) .. ' of size ' .. self:size() .. ']\n')
   return table.concat(strt)
end
-- Install every method of the local Storage table into each concrete
-- torch.<T>Storage metatable (rawset bypasses any __newindex).
for i = 1, #types do
   local storageMeta = torch.getmetatable('torch.' .. types[i] .. 'Storage')
   for name, fn in pairs(Storage) do
      rawset(storageMeta, name, fn)
   end
end
--- Render a 2-D tensor as rows of formatted numbers (no size footer).
-- The pattern and scale are derived from the tensor's whole underlying
-- storage, so slices of one larger tensor share a single format. Columns are
-- wrapped into "Columns a to b" chunks so lines stay within ~80 characters.
-- @param self a 2-D tensor
-- @param indent optional prefix for the printed lines (default '')
-- @return the rendered string
local function Tensor__printMatrix(self, indent)
   local format, scale, sz = Storage__printformat(self:storage())
   -- Match 'nan' anywhere: tostring(NaN) is spelled differently per platform.
   if format:lower():find('nan', 1, true) then
      format = '%f'
   end
   scale = scale or 1
   indent = indent or ''
   local nRows, nCols = self:size(1), self:size(2)
   local strt = {indent}
   local nColumnPerLine = math.floor((80 - #indent) / (sz + 1))
   if nColumnPerLine < 1 then
      -- A very long indent would yield 0 columns and never advance.
      -- (A NaN count fails this test and, as before, prints one chunk.)
      nColumnPerLine = 1
   end
   local firstColumn = 1
   local lastColumn = -1
   while firstColumn <= nCols do
      -- Explicit comparison (not math.min) so a NaN count still ends at nCols.
      if firstColumn + nColumnPerLine - 1 <= nCols then
         lastColumn = firstColumn + nColumnPerLine - 1
      else
         lastColumn = nCols
      end
      if nColumnPerLine < nCols then
         if firstColumn ~= 1 then
            table.insert(strt, '\n')
         end
         table.insert(strt, 'Columns ' .. firstColumn .. ' to ' .. lastColumn .. '\n' .. indent)
      end
      if scale ~= 1 then
         table.insert(strt, string.format('%g', scale) .. ' *\n ' .. indent)
      end
      for l = 1, nRows do
         local row = self:select(1, l)
         for c = firstColumn, lastColumn do
            table.insert(strt, string.format(format, row[c] / scale))
            if c == lastColumn then
               table.insert(strt, '\n')
               if l ~= nRows then
                  if scale ~= 1 then
                     table.insert(strt, indent .. ' ')
                  else
                     table.insert(strt, indent)
                  end
               end
            else
               table.insert(strt, ' ')
            end
         end
      end
      firstColumn = lastColumn + 1
   end
   return table.concat(strt)
end
--- Render a tensor with more than two dimensions as a sequence of 2-D
-- slices, each introduced by a "(i1,i2,...,.,.) = " header.
-- The leading indices are enumerated like an odometer whose FIRST digit
-- turns fastest.
local function Tensor__printTensor(self)
   local nOuter = self:nDimension() - 2
   local index = torch.LongStorage(nOuter)
   local pieces = {''}
   index:fill(1)
   index[1] = 0 -- the first advance() moves this to 1

   -- Step the odometer; returns false once the last digit overflows.
   local function advance()
      for d = 1, nOuter do
         index[d] = index[d] + 1
         if index[d] <= self:size(d) then
            return true
         elseif d == nOuter then
            return false
         end
         index[d] = 1
      end
   end

   while advance() do
      if #pieces > 1 then
         pieces[#pieces + 1] = '\n'
      end
      pieces[#pieces + 1] = '('
      local slice = self
      for d = 1, nOuter do
         slice = slice:select(1, index[d])
         pieces[#pieces + 1] = index[d] .. ','
      end
      pieces[#pieces + 1] = '.,.) = \n'
      pieces[#pieces + 1] = Tensor__printMatrix(slice, ' ')
   end
   return table.concat(pieces)
end
--- Render a tensor of any dimension as text, followed by a
-- "[torch.XTensor of size AxBx...]" footer (or a "with no dimension" note).
-- Values are first copied into a DoubleTensor so one code path formats every
-- element type; the footer still names the original type.
function Tensor.__tostring__(self)
   local strt = {''}
   if self:nDimension() == 0 then
      table.insert(strt, '[' .. torch.typename(self) .. ' with no dimension]\n')
      return table.concat(strt)
   end

   local tensor = torch.DoubleTensor():resize(self:size()):copy(self)
   local nDim = tensor:nDimension()
   if nDim == 1 then
      local format, scale = Storage__printformat(tensor:storage())
      -- tostring(NaN) is spelled differently across platforms ("nan", "-nan",
      -- "-nan(ind)", ...), so look for it anywhere in the pattern.
      if format:lower():find('nan', 1, true) then
         format = '%f'
      end
      if scale then
         table.insert(strt, string.format('%g', scale) .. ' *\n')
         for i = 1, tensor:size(1) do
            table.insert(strt, string.format(format, tensor[i] / scale) .. '\n')
         end
      else
         for i = 1, tensor:size(1) do
            table.insert(strt, string.format(format, tensor[i]) .. '\n')
         end
      end
      table.insert(strt, '[' .. torch.typename(self) .. ' of size ' .. tensor:size(1) .. ']\n')
   elseif nDim == 2 then
      table.insert(strt, Tensor__printMatrix(tensor))
      table.insert(strt, '[' .. torch.typename(self) .. ' of size ' .. tensor:size(1) .. 'x' .. tensor:size(2) .. ']\n')
   else
      table.insert(strt, Tensor__printTensor(tensor))
      table.insert(strt, '[' .. torch.typename(self) .. ' of size ')
      for i = 1, nDim do
         table.insert(strt, tensor:size(i))
         if i ~= nDim then
            table.insert(strt, 'x')
         end
      end
      table.insert(strt, ']\n')
   end
   return table.concat(strt)
end
--- Query or change the tensor type.
-- @param self source tensor
-- @param newType optional target type name, e.g. 'torch.FloatTensor'
-- @return the current type name when `newType` is omitted; otherwise `self`
--   if it already has that type, or a new tensor of that type holding a copy
--   of the data (left empty when `self` has no elements)
function Tensor.type(self, newType)
   local current = torch.typename(self)
   if not newType then
      return current
   end
   if newType == current then
      return self
   end
   local mt = torch.getmetatable(newType)
   if not mt then
      -- Report an unregistered name clearly instead of failing with
      -- "attempt to index a nil value" on `.new`.
      error('unknown tensor type: ' .. tostring(newType), 2)
   end
   local new = mt.new()
   if self:nElement() > 0 then
      new:resize(self:size()):copy(self)
   end
   return new
end
--- Convert `self` to the tensor type of `tensor` (see Tensor.type: `self`
-- is returned unchanged when the types already match).
function Tensor.typeAs(self, tensor)
   local targetType = tensor:type()
   return self:type(targetType)
end
-- Shorthand converters: tensor:float() is tensor:type('torch.FloatTensor'),
-- and so on for each element type.
for _, spec in ipairs({
   {'byte',   'torch.ByteTensor'},
   {'char',   'torch.CharTensor'},
   {'short',  'torch.ShortTensor'},
   {'int',    'torch.IntTensor'},
   {'long',   'torch.LongTensor'},
   {'float',  'torch.FloatTensor'},
   {'double', 'torch.DoubleTensor'},
   {'half',   'torch.HalfTensor'},
}) do
   local methodName, typeName = spec[1], spec[2]
   Tensor[methodName] = function(self)
      return self:type(typeName)
   end
end

--- Convert to whatever torch.getdefaulttensortype() reports at call time.
function Tensor.real(self)
   return self:type(torch.getdefaulttensortype())
end
--- Expand singleton dimensions of a tensor without copying data.
-- `result:expand(tensor, sizes...)` fills `result`; `tensor:expand(sizes...)`
-- (or torch.expand(tensor, sizes...)) returns a fresh tensor. Sizes are
-- numbers or a single torch.LongStorage and must cover every dimension.
-- Each size-1 dimension gets stride 0 so its single element is repeated;
-- any other dimension must already have the requested size.
function Tensor.expand(result, tensor, ...)
   local sizes = {...}
   local secondType = torch.type(tensor)
   if secondType == 'number' or secondType == 'torch.LongStorage' then
      -- Called as tensor:expand(sizes...): shift arguments, allocate result.
      table.insert(sizes, 1, tensor)
      tensor = result
      result = tensor.new()
   end

   local size = sizes[1]
   if torch.type(size) ~= 'torch.LongStorage' then
      size = torch.LongStorage(#sizes)
      for i, s in ipairs(sizes) do
         size[i] = s
      end
   end

   local nDim = tensor:dim()
   local newStride = tensor:stride()
   local newSize = tensor:size()
   if #size ~= tensor:dim() then
      error('the number of dimensions provided must equal tensor:dim()')
   end

   for d = 1, nDim do
      local current = newSize[d]
      if current == 1 then
         newSize[d] = size[d]
         newStride[d] = 0
      elseif current ~= size[d] then
         error('incorrect size: only supporting singleton expansion (size=1)')
      end
   end

   result:set(tensor:storage(), tensor:storageOffset(), newSize, newStride)
   return result
end
torch.expand = Tensor.expand
--- Expand a tensor to the size of a template tensor (see Tensor.expand).
-- `result:expandAs(tensor, template)` fills `result`;
-- `tensor:expandAs(template)` returns a new tensor.
function Tensor.expandAs(result, tensor, template)
   if not template then
      return result:expand(tensor:size())
   end
   return result:expand(tensor, template:size())
end
torch.expandAs = Tensor.expandAs
--- Tile a tensor: the output size is (source size, left-padded with 1s)
-- multiplied elementwise by the repeat counts.
-- `result:repeatTensor(tensor, counts...)` fills `result`;
-- `tensor:repeatTensor(counts...)` allocates a new result. Counts are numbers
-- or a single torch.LongStorage; there must be at least tensor:dim() of them.
-- Extra counts treat the source as having leading singleton dimensions.
function Tensor.repeatTensor(result,tensor,...)
-- get sizes
local sizes = {...}
local t = torch.type(tensor)
-- Called as tensor:repeatTensor(counts...): shift arguments, allocate result.
if (t == 'number' or t == 'torch.LongStorage') then
table.insert(sizes,1,tensor)
tensor = result
result = tensor.new()
end
-- if not contiguous, then force the tensor to be contiguous
if not tensor:isContiguous() then tensor = tensor:clone() end
-- check type
local size
if torch.type(sizes[1])=='torch.LongStorage' then
size = sizes[1]
else
size = torch.LongStorage(#sizes)
for i,s in ipairs(sizes) do
size[i] = s
end
end
if size:size() < tensor:dim() then
error('Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor')
end
-- xtensor: a second tensor over the same data whose shape can be changed
-- without touching `tensor`.
local xtensor = tensor.new():set(tensor)
local xsize = xtensor:size():totable()
-- Left-pad the source shape with 1s up to the number of repeat counts.
for i=1,size:size()-tensor:dim() do
table.insert(xsize,1,1)
end
-- Output size = padded source shape * repeat counts, elementwise (computed
-- with two 1-D DoubleTensors built from the lists, then back to a LongStorage).
size = torch.DoubleTensor(xsize):cmul(torch.DoubleTensor(size:totable())):long():storage()
xtensor:resize(torch.LongStorage(xsize))
result:resize(size)
-- urtensor must alias result's memory, because the final copy is what fills
-- `result`. NOTE(review): this relies on result.new(result) sharing storage.
-- Per torch's unfold(dim, size, step) semantics (confirm against the docs),
-- unfolding dim i with window = step = source size i leaves one entry per
-- repetition in dim i and appends a trailing dim holding one source-sized
-- block. urtensor's leading dims are therefore the repeat counts and its
-- trailing dims are the source shape.
local urtensor = result.new(result)
for i=1,xtensor:dim() do
urtensor = urtensor:unfold(i,xtensor:size(i),xtensor:size(i))
end
-- Prepend one singleton per dim added by unfold, so xtensor has as many dims
-- as urtensor and can be expanded against it.
for i=1,urtensor:dim()-xtensor:dim() do
table.insert(xsize,1,1)
end
xtensor:resize(torch.LongStorage(xsize))
-- Broadcast the source over every block and write it into result's memory.
local xxtensor = xtensor:expandAs(urtensor)
urtensor:copy(xxtensor)
return result
end
torch.repeatTensor = Tensor.repeatTensor
--- Resolve a size list in which at most one entry is -1.
-- The -1 entry is replaced by the length that makes the total number of
-- elements equal `nElements`, and a new LongStorage is returned. If no entry
-- is -1, `size` itself is returned unchanged.
-- @param size torch.LongStorage of requested dimension sizes
-- @param nElements total element count the size must account for
-- @return a LongStorage with every dimension specified
local function specifyFully(size, nElements)
   local nCoveredElements = 1
   local remainingDim = nil
   local sizes = size:totable()
   for i = 1, #sizes do
      local wantedDimSize = sizes[i]
      if wantedDimSize == -1 then
         if remainingDim then
            error("Only one of torch.view dimensions can be -1.")
         end
         remainingDim = i
      else
         nCoveredElements = nCoveredElements * wantedDimSize
      end
   end
   if not remainingDim then
      return size
   end
   -- With a zero-sized known dimension the -1 length is undetermined, and
   -- `x % 0` is NaN (Lua 5.1) or a runtime error (Lua 5.3 integers).
   if nCoveredElements == 0 then
      error("Cannot infer the -1 dimension when another dimension has size 0.")
   end
   assert(nElements % nCoveredElements == 0,
          "The number of elements is not a multiple of the product of the specified dimensions.")
   local copy = torch.LongStorage(sizes)
   -- The assert guarantees an exact division; math.floor keeps the value an
   -- integer on Lua 5.3+, where `/` always produces a float.
   copy[remainingDim] = math.floor(nElements / nCoveredElements)
   return copy
end
-- TODO: This should be implemented in TH and wrapped.
--- Return a tensor that shares the source's storage under a new size.
-- Accepted call forms:
--   view(result, src, d1, d2, ...)  -- writes the view into `result`
--   view(result, src, sizeStorage)
--   src:view(sizeStorage)           -- allocates a new tensor for the view
--   src:view(d1, d2, ...)
-- At most one size may be -1; it is inferred from the element count.
-- The source must be contiguous and the element count must not change.
function Tensor.view(result, src, ...)
   local size = ...
   local view, tensor
   local function istensor(tensor)
      return torch.typename(tensor) and torch.typename(tensor):find('torch.*Tensor')
   end
   local function isstorage(storage)
      return torch.typename(storage) and torch.typename(storage) == 'torch.LongStorage'
   end
   if istensor(result) and istensor(src) and type(size) == 'number' then
      size = torch.LongStorage{...}
      view = result
      tensor = src
   elseif istensor(result) and istensor(src) and isstorage(size) then
      view = result
      tensor = src
   elseif istensor(result) and isstorage(src) and size == nil then
      size = src
      tensor = result
      view = tensor.new()
   elseif istensor(result) and type(src) == 'number' then
      size = {...}
      table.insert(size, 1, src)
      size = torch.LongStorage(size)
      tensor = result
      view = tensor.new()
   else
      local t1 = 'torch.Tensor, torch.Tensor, number [, number ]*'
      local t2 = 'torch.Tensor, torch.Tensor, torch.LongStorage'
      local t3 = 'torch.Tensor, torch.LongStorage'
      local t4 = 'torch.Tensor, number [, number ]*'
      error(string.format('torch.view, expected (%s) or\n (%s) or\n (%s)\n or (%s)', t1, t2, t3, t4))
   end

   local origNElement = tensor:nElement()
   size = specifyFully(size, origNElement)
   assert(tensor:isContiguous(), "expecting a contiguous tensor")

   -- Check the element count BEFORE set(). set() reshapes `view`, which may
   -- be a caller-supplied tensor, so a failing call must not touch it (and an
   -- oversized request could otherwise reach the shared storage).
   -- The count mirrors how the result tensor's nElement() was compared before:
   -- sizes are taken up to the first non-positive entry, and zero such
   -- dimensions means zero elements.
   -- NOTE(review): this assumes TH's set()/resize truncates at the first
   -- size <= 0; confirm against THTensor_(resizeNd).
   local nLeadingDims, newNElement = 0, 1
   for i = 1, size:size() do
      if size[i] <= 0 then
         break
      end
      nLeadingDims = nLeadingDims + 1
      newNElement = newNElement * size[i]
   end
   if nLeadingDims == 0 then
      newNElement = 0
   end
   if newNElement ~= origNElement then
      local inputSize = table.concat(tensor:size():totable(), "x")
      local outputSize = table.concat(size:totable(), "x")
      error(string.format("Wrong size for view. Input size: %s. Output size: %s",
                          inputSize, outputSize))
   end

   view:set(tensor:storage(), tensor:storageOffset(), size)
   return view
end
torch.view = Tensor.view
--- View a tensor with the size of a template tensor (see Tensor.view).
-- `result:viewAs(src, template)` writes the view into `result`;
-- `src:viewAs(template)` allocates a new tensor.
function Tensor.viewAs(result, src, template)
   if template == nil then
      -- Two-argument form: (src, template) arrived as (result, src).
      local source, shape = result, src
      return source.new():view(source, shape:size())
   end
   if template and torch.typename(template) then
      return result:view(src, template:size())
   end
   local t1 = 'torch.Tensor, torch.Tensor, torch.LongStorage'
   local t2 = 'torch.Tensor, torch.LongStorage'
   error(string.format('expecting (%s) or (%s)', t1, t2))
end
torch.viewAs = Tensor.viewAs
--- Split a tensor into consecutive pieces of length `splitSize` along `dim`.
-- Every piece is a narrow() of `tensor`. The last piece is shorter when the
-- size along `dim` is not a multiple of `splitSize`.
-- Call forms: `torch.split(result, tensor, splitSize [, dim])` clears and
-- fills the table `result`; `tensor:split(splitSize [, dim])` returns a new
-- table.
-- @param dim dimension to split along (default 1)
-- @return the table of pieces
function Tensor.split(result, tensor, splitSize, dim)
   if torch.type(result) ~= 'table' then
      dim = splitSize
      splitSize = tensor
      tensor = result
      result = {}
   else
      -- Reuse the caller's table, dropping whatever it held before.
      for k in pairs(result) do
         result[k] = nil
      end
   end
   -- A non-positive step never advances `start`. That would be an infinite
   -- loop unless narrow() happened to reject the size, so fail clearly here.
   if type(splitSize) ~= 'number' or splitSize <= 0 then
      error('split: splitSize must be a positive number, got ' .. tostring(splitSize), 2)
   end
   dim = dim or 1
   local total = tensor:size(dim)
   local start = 1
   while start <= total do
      local length = math.min(splitSize, total - start + 1)
      result[#result + 1] = tensor:narrow(dim, start, length)
      start = start + length
   end
   return result
end
torch.split = Tensor.split
--- Split a tensor along `dim` into at most `nChunk` equally sized pieces.
-- The piece length is ceil(size(dim) / nChunk), so fewer than `nChunk`
-- pieces can come back (size 5 into 4 chunks yields pieces of 2, 2 and 1).
-- Call forms mirror Tensor.split: pass a table first to have it filled.
-- @param dim dimension to split along (default 1)
-- @return the table of pieces
function Tensor.chunk(result, tensor, nChunk, dim)
   if torch.type(result) ~= 'table' then
      dim = nChunk
      nChunk = tensor
      tensor = result
      result = {}
   end
   -- nChunk = 0 would make the piece length infinite (silently returning one
   -- piece), and a negative count makes it negative; reject both up front.
   if type(nChunk) ~= 'number' or nChunk < 1 then
      error('chunk: nChunk must be a number >= 1, got ' .. tostring(nChunk), 2)
   end
   dim = dim or 1
   local splitSize = math.ceil(tensor:size(dim) / nChunk)
   return torch.split(result, tensor, splitSize, dim)
end
torch.chunk = Tensor.chunk
--- Convert a tensor into nested Lua tables, one nesting level per dimension.
-- A tensor with no dimensions yields an empty table.
function Tensor.totable(tensor)
   local out = {}
   local nDim = tensor:dim()
   if nDim == 1 then
      -- The callback returns nothing, so apply() leaves the elements as is.
      tensor:apply(function(value)
         out[#out + 1] = value
      end)
   elseif nDim > 0 then
      for i = 1, tensor:size(1) do
         out[#out + 1] = tensor[i]:totable()
      end
   end
   return out
end
torch.totable = Tensor.totable
--- Reorder the dimensions of a tensor.
-- `tensor:permute(p1, ..., pn)` returns a tensor whose i-th dimension is
-- dimension p_i of the input. For example, sizes (a,b,c) with (2,3,1) give
-- (b,c,a). Exactly tensor:dim() indices are required. The result is built
-- purely from transpose() calls (so no copy, assuming transpose returns a
-- view).
function Tensor.permute(tensor, ...)
local perm = {...}
local nDims = tensor:dim()
assert(#perm == nDims, 'Invalid permutation')
local j
-- Follow each cycle of the permutation and realise it as a chain of
-- transposes. Visited entries are overwritten with 0 (or with their own
-- index when a cycle closes) so the outer ipairs loop skips them.
for i, p in ipairs(perm) do
if p ~= i and p ~= 0 then
j = i
repeat
-- Rejects out-of-range indices, and also entries already zeroed in this
-- pass, which is what a duplicate index in `perm` leads to.
assert(0 < perm[j] and perm[j] <= nDims, 'Invalid permutation')
tensor = tensor:transpose(j, perm[j])
-- Lua evaluates the right-hand side and the index j in perm[j] before
-- assigning: this zeroes the OLD position and steps j to the next one.
j, perm[j] = perm[j], 0
until perm[j] == i
-- The cycle closed at j: mark it as a fixed point.
perm[j] = j
end
end
return tensor
end
torch.permute = Tensor.permute
-- Install every method of the local Tensor table into each concrete
-- torch.<T>Tensor metatable (rawset bypasses any __newindex).
-- HalfTensor then gets its own `totable`, which converts to a FloatTensor
-- first and reuses the generic implementation on that copy.
for _, typeName in ipairs(types) do
   local metatable = torch.getmetatable('torch.' .. typeName .. 'Tensor')
   for funcname, func in pairs(Tensor) do
      rawset(metatable, funcname, func)
   end
   if typeName == 'Half' then
      -- Convert once. The previous version called self:float() twice and
      -- discarded the first result.
      rawset(metatable, 'totable', function(self)
         return self:float():totable()
      end)
   end
end