ConcatTable nested table input #38

Merged
5 commits, merged on Jul 17, 2014
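
In short: this patch lets `nn.ConcatTable` accept table (and arbitrarily nested table) inputs, not just Tensors. `updateGradInput` gains a recursive `retable` helper that mirrors the nesting of the incoming gradients, a `type` override converts table-valued `gradInput` fields, and the docs and tests are updated to match. A minimal usage sketch of the new behaviour, adapted from the documentation example added below:

```lua
require 'nn'

-- Each member module now receives the (possibly nested) table input as-is.
local mlp = nn.ConcatTable()
mlp:add(nn.Identity())
mlp:add(nn.Identity())

local input = {torch.randn(2), {torch.randn(3)}}
local pred = mlp:forward(input)  -- {input, input}: one entry per member module

-- gradOutput supplies one gradient table per member; gradInput is their
-- leaf-wise sum, with the same nesting as the input.
local gradInput = mlp:backward(input, {pred[1], pred[2]})
```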
76 changes: 52 additions & 24 deletions ConcatTable.lua
```diff
@@ -26,33 +26,52 @@ function ConcatTable:updateOutput(input)
    return self.output
 end
 
+local function retable(t1, t2, f)
+   for k, v in pairs(t2) do
+      if (torch.type(v) == "table") then
+         t1[k] = retable(t1[k] or {}, t2[k], f)
+      else
+         f(t1, k, v)
+      end
+   end
+   return t1
+end
+
 function ConcatTable:updateGradInput(input, gradOutput)
-   for i,module in ipairs(self.modules) do
-      local currentGradInput = module:updateGradInput(input, gradOutput[i])
-      if i == 1 then
-         if type(input) == 'table' then
-            assert(type(currentGradInput) == 'table',
-                   'currentGradInput is not a table!')
-            assert(#input == #currentGradInput,
-                   'table size mismatch')
-            -- gradInput is also a table
-            self.gradInput = {}
-            for j = 1, #currentGradInput do
-               self.gradInput[j] = currentGradInput[j]:clone()
-            end
-         else
-            -- gradInput is a tensor
-            self.gradInput:resizeAs(currentGradInput):copy(currentGradInput)
-         end
-      else
-         if type(input) == 'table' then
-            assert(type(currentGradInput) == 'table',
-                   'currentGradInput is not a table!')
-            assert(#input == #currentGradInput,
-                   'table size mismatch')
-            for j = 1, #self.gradInput do
-               self.gradInput[j]:add(currentGradInput[j])
-            end
+   local isTable = torch.type(input) == 'table'
+   local wasTable = torch.type(self.gradInput) == 'table'
+   if isTable then
+      for i,module in ipairs(self.modules) do
+         local currentGradInput = module:updateGradInput(input, gradOutput[i])
+         if torch.type(currentGradInput) ~= 'table' then
+            error"currentGradInput is not a table!"
+         end
+         if #input ~= #currentGradInput then
+            error("table size mismatch: "..#input.." ~= "..#currentGradInput)
+         end
+         if i == 1 then
+            self.gradInput = wasTable and self.gradInput or {}
+            retable(self.gradInput, currentGradInput,
+               function(t, k, v)
+                  t[k] = t[k] or v:clone()
+                  t[k]:resizeAs(v)
+                  t[k]:copy(v)
+               end
+            )
+         else
+            retable(self.gradInput, currentGradInput,
+               function(t, k, v)
+                  t[k]:add(v)
+               end
+            )
+         end
+      end
+   else
+      self.gradInput = (not wasTable) and self.gradInput or input:clone()
+      for i,module in ipairs(self.modules) do
+         local currentGradInput = module:updateGradInput(input, gradOutput[i])
+         if i == 1 then
+            self.gradInput:resizeAs(currentGradInput):copy(currentGradInput)
          else
             self.gradInput:add(currentGradInput)
          end
```
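
The heart of the change is `retable`: it walks the (possibly nested) gradient table `t2`, creating matching sub-tables in `t1` on demand, and applies `f` at each tensor leaf. The first member module's callback allocates and copies into the leaves; subsequent members accumulate with `add`. A standalone sketch of the two passes, with illustrative values (`retable` is file-local in ConcatTable.lua, so it is duplicated here):

```lua
require 'torch'

-- Copy of the file-local helper from ConcatTable.lua, for illustration only.
local function retable(t1, t2, f)
   for k, v in pairs(t2) do
      if (torch.type(v) == "table") then
         t1[k] = retable(t1[k] or {}, t2[k], f)
      else
         f(t1, k, v)
      end
   end
   return t1
end

-- Pass 1 (i == 1): allocate-and-copy each leaf.
local grad = retable({}, {torch.ones(2), {torch.ones(3)}},
   function(t, k, v)
      t[k] = t[k] or v:clone()
      t[k]:resizeAs(v)
      t[k]:copy(v)
   end)

-- Pass 2 (i > 1): accumulate into the existing leaves.
retable(grad, {torch.ones(2), {torch.ones(3)}},
   function(t, k, v) t[k]:add(v) end)
-- grad[1] and grad[2][1] now hold 2s, mirroring the input nesting.
```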
```diff
@@ -126,6 +145,15 @@ function ConcatTable:parameters()
    return w,gw
 end
 
+function ConcatTable:type(type)
+   parent.type(self, type)
+   if torch.type(self.gradInput) == 'table' then
+      for i, gradInput in ipairs(self.gradInput) do
+         self.gradInput[i] = gradInput:type(type)
+      end
+   end
+end
+
 function ConcatTable:__tostring__()
    local tab = '  '
    local line = '\n'
```
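
The `type` override exists because, after a table-input backward pass, `self.gradInput` is a plain Lua table of tensors, presumably skipped by the stock `nn.Module:type` of the era (an assumption about the 2014 module code; the override makes the conversion explicit either way). Note the loop converts only top-level tensor entries of `gradInput`. A small usage sketch:

```lua
local m = nn.ConcatTable()
m:add(nn.Identity())
m:add(nn.Identity())

local x = {torch.randn(2), torch.randn(3)}
m:forward(x)
m:backward(x, {x, x})   -- gradInput is now a table of DoubleTensors

m:float()               -- ConcatTable:type converts the gradInput entries too
print(torch.type(m.gradInput[1]))  -- torch.FloatTensor
```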
36 changes: 34 additions & 2 deletions doc/table.md
````diff
@@ -35,9 +35,10 @@ pred=mlp:forward{x,y,z} -- This is equivalent to the line before
 ## ConcatTable ##
 
 ConcatTable is a container module that applies each member module to
-the same input [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor).
+the same input [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor)
+or Table.
 
-Example:
+Example 1:
 ```lua
 mlp= nn.ConcatTable()
 mlp:add(nn.Linear(5,2))
````
````diff
@@ -60,6 +61,37 @@ which gives the output:
  [torch.Tensor of dimension 3]
 ```
 
+Example 2:
+```lua
+mlp= nn.ConcatTable()
+mlp:add(nn.Identity())
+mlp:add(nn.Identity())
+
+pred=mlp:forward{torch.randn(2),{torch.randn(3)}};
+print(pred)
+```
+which gives the output (using [th](https://github.com/torch/trepl)):
+```lua
+{
+  1 :
+   {
+     1 : DoubleTensor - size: 2
+     2 :
+      {
+        1 : DoubleTensor - size: 3
+      }
+   }
+  2 :
+   {
+     1 : DoubleTensor - size: 2
+     2 :
+      {
+        1 : DoubleTensor - size: 3
+      }
+   }
+}
+
+```
 <a name="nn.ParallelTable"/>
 ## ParallelTable ##
 
````
72 changes: 33 additions & 39 deletions test/test.lua
```diff
@@ -9,6 +9,16 @@ local expprecision = 1e-4
 
 local nntest = {}
 
+local function equal(t1, t2, msg)
+   if (torch.type(t1) == "table") then
+      for k, v in pairs(t2) do
+         equal(t1[k], t2[k])
+      end
+   else
+      mytester:assertTensorEq(t1, t2, 0.00001, msg)
+   end
+end
+
 function nntest.Add()
    local ini = math.random(10,20)
    local inj = math.random(10,20)
```
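
For context, the `equal` helper is hoisted to the top of test.lua so that `nntest.ConcatTable` can reuse it (the next hunk removes the original copy from `nntest.SelectTable`). It recurses through the expected table's keys and defers to `assertTensorEq` at tensor leaves; a hypothetical call inside a test case:

```lua
-- Passes iff every nested tensor matches element-wise to within 1e-5.
equal({torch.ones(2), {torch.ones(3)}},
      {torch.ones(2), {torch.ones(3)}},
      "nested tables should match")
```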
```diff
@@ -1915,15 +1925,6 @@ function nntest.SelectTable()
       {torch.Tensor(3,4,5):zero()},
       {torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}}
    }
-   local function equal(t1, t2, msg)
-      if (torch.type(t1) == "table") then
-         for k, v in pairs(t2) do
-            equal(t1[k], t2[k])
-         end
-      else
-         mytester:assertTensorEq(t1, t2, 0.00001, msg)
-      end
-   end
    local nonIdx = {2,3,4,1}
    local module
    for idx = 1,#input do
```
```diff
@@ -2012,36 +2013,29 @@ function nntest.ConcatTable()
    mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')
 
    -- Now test a table input
-   -- jac needs a tensor input, so we have to form a network that creates
-   -- a table internally: Do this using a Reshape and a SplitTable
-   m = nn.Sequential()
-   m:add(nn.Reshape(1,10,10,10))
-   m:add(nn.SplitTable(1)) -- output of Split table is a table of length 1
-
-   concat = nn.ConcatTable()
-   concat:add(nn.JoinTable(1))
-
-   m:add(concat)
-   m:add(nn.JoinTable(1))
-
-   err = jac.testJacobian(m, input)
-   mytester:assertlt(err, precision, ' error on state ')
-
-   ferr, berr = jac.testIO(m, input)
-   mytester:asserteq(ferr, 0, torch.typename(m)..' - i/o forward err ')
-   mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')
-
-   -- As per Soumith's suggestion, make sure getParameters works:
-   m = nn.ConcatTable()
-   local l = nn.Linear(16,16)
-   m:add(l)
-   mparams = m:getParameters()
-   -- I don't know of a way to make sure that the storage is equal, however
-   -- the linear weight and bias will be randomly initialized, so just make
-   -- sure both parameter sets are equal
-   lparams = l:getParameters()
-   err = (mparams - lparams):abs():max()
-   mytester:assertlt(err, precision, ' getParameters error ')
+   local input = {
+      torch.randn(3,4):float(), torch.randn(3,4):float(), {torch.randn(3,4):float()}
+   }
+   local _gradOutput = {
+      torch.randn(3,3,4):float(), torch.randn(3,3,4):float(), torch.randn(3,3,4):float()
+   }
+   local gradOutput = {
+      {_gradOutput[1][1], _gradOutput[2][1], {_gradOutput[3][1]}},
+      {_gradOutput[1][2], _gradOutput[2][2], {_gradOutput[3][2]}},
+      {_gradOutput[1][3], _gradOutput[2][3], {_gradOutput[3][3]}}
+   }
+   local module = nn.ConcatTable()
+   module:add(nn.Identity())
+   module:add(nn.Identity())
+   module:add(nn.Identity())
+   module:float()
+
+   local output = module:forward(input)
+   local output2 = {input, input, input}
+   equal(output2, output, "ConcatTable table output")
+   local gradInput = module:backward(input, gradOutput)
+   local gradInput2 = {_gradOutput[1]:sum(1), _gradOutput[2]:sum(1), {_gradOutput[3]:sum(1)}}
+   equal(gradInput, gradInput2, "ConcatTable table gradInput")
 end
 
 mytester:add(nntest)
```
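
As a sanity check on the expected values in this test: each of the three `Identity` members sees the same table input, so `ConcatTable`'s `gradInput` is the leaf-wise sum of the three per-member gradient tables, i.e. `gradOutput[1][j] + gradOutput[2][j] + gradOutput[3][j]`, which by construction of `gradOutput` equals `_gradOutput[j]:sum(1)`. A standalone verification of that identity, with hypothetical values:

```lua
local g = torch.randn(3,3,4)            -- stands in for one _gradOutput[j]
local manual = g[1] + g[2] + g[3]       -- sum of the three slices along dim 1
assert((manual - g:sum(1):squeeze()):abs():max() < 1e-10)
```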