local FlattenTable, parent = torch.class('nn.FlattenTable', 'nn.Module')

function FlattenTable:__init()
   parent.__init(self)

   self.output = {}
   self.input_map = {}
   self.gradInput = {}
end

-- Recursive function to flatten a table (output is a table)
local function flatten(output, input)
   local input_map -- has the same structure as input, but stores the
                   -- indices to the corresponding output
   if torch.type(input) == 'table' then
      input_map = {}
      -- forward DFS order
      for i = 1, #input do
         input_map[#input_map+1] = flatten(output, input[i])
      end
   else
      input_map = #output + 1
      output[input_map] = input -- append the tensor
   end
   return input_map
end
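
-- Illustrative example (t1, t2, t3 are placeholder tensors, not part of the
-- module): for input = {t1, {t2, t3}}, flatten(output, input) fills
-- output = {t1, t2, t3} and returns input_map = {1, {2, 3}}; input_map
-- mirrors the nesting of the input, and each leaf holds the flat index of
-- its tensor in output.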

-- Recursive function to check if we need to rebuild the output table
local function checkMapping(output, input, input_map)
   if input_map == nil or output == nil or input == nil then
      return false
   end
   if torch.type(input) == 'table' then
      if torch.type(input_map) ~= 'table' then
         return false
      end
      if #input ~= #input_map then
         return false
      end
      -- forward DFS order
      for i = 1, #input do
         local ok = checkMapping(output, input[i], input_map[i])
         if not ok then
            return false
         end
      end
      return true
   else
      if torch.type(input_map) ~= 'number' then
         return false
      end
      return output[input_map] == input
   end
end
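
-- In other words, checkMapping returns false if any tensor in input is no
-- longer the same object referenced by output, or if the nesting of input no
-- longer matches input_map; updateOutput below then rebuilds the flattened
-- table.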

-- During BPROP we have to build a gradInput with the same shape as the
-- input. This is a recursive function to build up a gradInput.
local function inverseFlatten(gradOutput, input_map)
   if torch.type(input_map) == 'table' then
      local gradInput = {}
      for i = 1, #input_map do
         gradInput[#gradInput + 1] = inverseFlatten(gradOutput, input_map[i])
      end
      return gradInput
   else
      return gradOutput[input_map]
   end
end
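
-- Illustrative example: with input_map = {1, {2, 3}} (as in the flatten()
-- example above) and gradOutput = {g1, g2, g3}, inverseFlatten returns
-- {g1, {g2, g3}}, i.e. the gradients rearranged back into the nesting of the
-- original input (g1, g2, g3 are placeholder tensors).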

function FlattenTable:updateOutput(input)
   assert(torch.type(input) == 'table', 'input must be a table')
   -- To avoid rebuilding the flattened table on every updateOutput call, we
   -- do a DFS pass over the existing output table and the input to see if it
   -- needs to be rebuilt.
   if not checkMapping(self.output, input, self.input_map) then
      self.output = {}
      self.input_map = flatten(self.output, input)
   end
   return self.output
end

function FlattenTable:updateGradInput(input, gradOutput)
   assert(torch.type(input) == 'table', 'input must be a table')
   assert(torch.type(gradOutput) == 'table', 'gradOutput must be a table')
   -- If the input changes between the updateOutput and updateGradInput calls,
   -- then we may have to rebuild the input_map. Here we assume the input_map
   -- is valid and that forward has already been called, but we still check
   -- that the existing gradInput matches the current gradOutput:
   if not checkMapping(gradOutput, self.gradInput, self.input_map) then
      self.gradInput = inverseFlatten(gradOutput, self.input_map)
   end
   return self.gradInput
end

function FlattenTable:type(type, tensorCache)
   -- The module only stores references, so no type conversion is needed.
   -- Just force the tables to be empty.
   self:clearState()
end

function FlattenTable:clearState()
   self.input_map = {}
   return parent.clearState(self)
end
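
-- Minimal usage sketch (assumes the 'nn' package is loaded via require 'nn';
-- the tensor names below are illustrative, not part of the module):
--
--   local m = nn.FlattenTable()
--   local t1, t2, t3 = torch.rand(2), torch.rand(2), torch.rand(2)
--   local out = m:forward({t1, {t2, {t3}}})
--   -- out is the flat table {t1, t2, t3}; the tensors are referenced, not copied
--   local go = {torch.rand(2), torch.rand(2), torch.rand(2)}
--   local grad = m:backward({t1, {t2, {t3}}}, go)
--   -- grad mirrors the input nesting: {go[1], {go[2], {go[3]}}}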