This repository has been archived by the owner on Nov 1, 2021. It is now read-only.

Commit cd6b9e0: Nested gradient support.
lukealonso committed Dec 8, 2015
1 parent 55e20ee
Showing 7 changed files with 83 additions and 30 deletions.
30 changes: 14 additions & 16 deletions src/main.lua
@@ -14,23 +14,20 @@ local function optimize(opt)
defaultOptimize = opt
end

local defaultProtected = false
local function protected(prot)
defaultProtected = prot
end

local function grad(fn, gradOpt)
gradOpt = gradOpt or { }
local argnum = gradOpt.gradArg or 1
local optimize = util.defaultBool(gradOpt.optimize, defaultOptimize)
local withForward = util.defaultBool(gradOpt.withForward, true)
local withGradients = util.defaultBool(gradOpt.withGradients, true)
local partialGrad = util.defaultBool(gradOpt.partialGrad, false)
local debugHook = gradOpt.debugHook
local signatureFn = gradOpt.signatureFn
local opt = {
argnum = argnum,
withForward = withForward,
withGradients = withGradients,
partialGrad = partialGrad,
debugHook = debugHook,
signatureFn = signatureFn
}
local opt = util.deepCopy(gradOpt)
opt.argnum = opt.gradArg or 1
opt.optimize = util.defaultBool(opt.optimize, defaultOptimize)
opt.protected = util.defaultBool(opt.protected, defaultProtected)
opt.withForward = util.defaultBool(opt.withForward, true)
opt.withGradients = util.defaultBool(opt.withGradients, true)
opt.partialGrad = util.defaultBool(opt.partialGrad, false)
if optimize then
return RuntimeCodegen.create(fn, opt)
else
@@ -42,7 +39,8 @@ end
local autograd = {
grad = grad,
overload = overload,
optimize = optimize
optimize = optimize,
protected = protected
}

-- Shortcut:
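grad now deep-copies the caller's option table and fills in defaults (including the new module-level protected flag) instead of copying a fixed set of fields, so extra options such as debugHook and signatureFn ride along automatically. A hedged usage sketch follows; the function and parameter names (f, params, x) are illustrative, not from this commit.

    local autograd = require 'autograd'

    -- New module-level default: wrap every recording in pcall (see Graph.record below).
    autograd.protected(true)

    local f = function(params, x)
       return torch.sum(params.W * x + params.b)
    end

    -- Fields left out of the option table fall back to the module-level defaults;
    -- protected (and optimize) can still be overridden per grad function.
    local df = autograd.grad(f, { optimize = true, protected = false })
    local grads, loss = df({ W = torch.randn(4, 3), b = torch.randn(4) }, torch.randn(3))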
10 changes: 10 additions & 0 deletions src/overload.lua
@@ -149,7 +149,13 @@ local function module(name, table, fn)
end
end

local installDepth = 0

local function install(fn)
installDepth = installDepth + 1
if installDepth ~= 1 then
return
end
if #toRegister > 0 then
for i = 1, #toRegister do
toRegister[i]()
@@ -176,6 +182,10 @@ local function install(fn)
end

local function uninstall()
installDepth = installDepth - 1
if installDepth ~= 0 then
return
end
nnwrapper.setApplyFn(nil)
for i = 1, #overloads do
local mm = overloads[i]
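With nested gradients, overload.install can now be called while a recording is already active, so install/uninstall become depth-counted: only the outermost pair actually registers and removes the torch overloads. A standalone sketch of that contract (illustrative, not commit code):

    -- Standalone sketch, not commit code: only the outermost install/uninstall pair
    -- touches the overload registry, so a nested recording cannot tear down the
    -- overloads of the recording that contains it.
    local installDepth = 0
    local installed = false

    local function install()
       installDepth = installDepth + 1
       if installDepth == 1 then installed = true end
    end

    local function uninstall()
       installDepth = installDepth - 1
       if installDepth == 0 then installed = false end
    end

    install()                -- outer Graph.record starts
    install()                -- nested Graph.record: no-op
    uninstall()              -- nested teardown: overloads must survive
    assert(installed)
    uninstall()              -- outer teardown
    assert(not installed)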
22 changes: 16 additions & 6 deletions src/runtime/codegen/Graph.lua
@@ -10,12 +10,12 @@ local util = require 'autograd.util'

local nodeDebugger
local applyDepth = 0
local nodeDisabled = true
local reentryDepth = 0

local function overloadHook(fn, gradFn, capture, ...)
local inputs = {...}
applyDepth = applyDepth + 1
if not nodeDisabled and applyDepth == 1 and capture then
if reentryDepth ~= 0 and applyDepth == 1 and capture then
local n = Node.new(fn, gradFn, inputs)
local values = {n:evaluateForward()}
if nodeDebugger then
@@ -248,6 +248,10 @@ function Graph:walkExecutionOrder(withForward, withGradients)
return execOrder, outputNodes
end

function Graph.reentryDepth()
return reentryDepth
end

function Graph.record(fn, args, opt)
local argnum = opt.argnum or 1
local debugger = opt.debugger
@@ -267,15 +271,14 @@ function Graph.record(fn, args, opt)
-- Begin recording all torch operations.
overload.install(overloadHook)
applyDepth = 0
nodeDisabled = false
reentryDepth = reentryDepth + 1
nodeDebugger = debugger

-- Call user forward function.
local answers = nil

local protectedFn = function()
answers = {fn(table.unpack(values))}

-- Figure out forward graph traversal order.
-- Only walk from the answer we need to differentiate (usually the first).
local forwardExecOrder = walkOutputRoots(answers[argnum])
@@ -295,13 +298,20 @@
node:evaluateBackward()
end
end
return true
end

local ok, msg = pcall(protectedFn)
local ok, msg

if opt.protected then
ok, msg = pcall(protectedFn)
else
ok, msg = protectedFn()
end

-- End recording.
nodeDebugger = nil
nodeDisabled = true
reentryDepth = reentryDepth - 1
overload.uninstall()

if not ok then
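Two behavioural changes land here: the nodeDisabled flag becomes a reentryDepth counter, so a nested Graph.record keeps capturing nodes into the live graph instead of being silently disabled, and the pcall around the user function is now gated on opt.protected. A hedged sketch of the second change from the caller's side; the names are illustrative, and the error handling after the if-not-ok branch is cut off above.

    local autograd = require 'autograd'

    local exploding = function(params, x)
       error('bad value in forward pass')      -- user error during recording
    end

    -- protected = true: Graph.record runs the forward/backward pass under pcall, so
    -- reentryDepth, the debugger and the overloads are unwound before the failure
    -- is reported. optimize = true keeps us on the codegen runtime shown in this file.
    local df = autograd.grad(exploding, { protected = true, optimize = true })
    print(pcall(df, { W = torch.randn(2) }, torch.randn(2)))   -- false, plus the message

    -- protected = false (the default unless autograd.protected(true) was called):
    -- no pcall wrapper, so the error propagates straight out of the forward call.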
21 changes: 21 additions & 0 deletions src/runtime/codegen/Value.lua
@@ -94,6 +94,8 @@ function Value.flatten(v)
rawTable[k] = Value.flatten(v)
end
return rawTable
else
return v
end
elseif v.type == Value.TABLE then
return Value.flatten(v.raw)
@@ -121,6 +123,25 @@ function Value.flattenGrads(v)
end
end

function Value.collectGrads(v)
if not Value.isValue(v) then
if type(v) == "table" then
local rawTable = { }
for k,v in pairs(v) do
rawTable[k] = Value.collectGrads(v)
end
return rawTable
end
elseif v.type == Value.TABLE then
return Value.collectGrads(v.raw)
else
if v.source.gradients then
return v.source.gradients[1]
end
return nil
end
end

-- These exist only to be overloaded and called with flattened tensor or number arguments

function Value.__add(a, b)
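Value.collectGrads is the nesting-aware counterpart of flattenGrads: it mirrors a possibly nested table of values, pulling each leaf's first recorded gradient and returning nil for leaves that never received one, while Value.flatten now passes plain non-table values straight through. The practical effect is that gradients keep the shape of the parameter table, as in this hedged sketch (illustrative names, not commit code):

    local autograd = require 'autograd'

    local f = function(params, x)
       local h = params.layer1.W * x + params.layer1.b
       return torch.sum(params.layer2.W * h)
    end

    local params = {
       layer1 = { W = torch.randn(4, 3), b = torch.randn(4) },
       layer2 = { W = torch.randn(1, 4) },
    }

    local df = autograd.grad(f)
    local grads = df(params, torch.randn(3))
    -- grads.layer1.W, grads.layer1.b and grads.layer2.W mirror params' structure.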
19 changes: 12 additions & 7 deletions src/runtime/codegen/init.lua
@@ -16,13 +16,16 @@ local function buildSignature(params, tensorDims)
end
end

local function execUncached(fn, args, opt)
local graph = createGraph(fn, args, opt, debugger)
local retValues = { Value.flattenGrads(graph.params[opt.argnum]) }
local function execUncached(fn, args, opt, nestedGradient)
local graph = Graph.record(fn, args, opt)
local retValues = { Value.collectGrads(graph.params[opt.argnum]) }
for i = 1, #graph.answers do
retValues[#retValues + 1] = flattenAnswer(graph.answers[i])
retValues[#retValues + 1] = graph.answers[i]
end
if not nestedGradient then
retValues = Value.flatten(retValues)
end
return table.unpack(retValues)
return unpack(retValues)
end

local function printPoolStats(tensorPool)
@@ -55,11 +58,13 @@ local function create(fn, opt)
return table.concat(tensorDims, "-")
end
local signature = sigFun(args)
if signature == nil then
return execUncached(fn, args, opt)
if signature == nil or Graph.reentryDepth() > 0 then
-- If we're in the middle of building the graph for a parent function, include this one in the parent, don't codegen.
return execUncached(fn, args, opt, Graph.reentryDepth() > 0)
end
if generatedFunctions[signature] == nil then
local gradFn, retValues, code = generateFn(fn, args, opt)
--print(code)
generatedFunctions[signature] = gradFn
-- We already have the answers, don't run it all over again.
if opt.withGradients and opt.withForward and not opt.debugHook then
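This is the core of the nested-gradient support: when create is entered while an outer graph is still being recorded (Graph.reentryDepth() > 0), it skips signature lookup and code generation and evaluates uncached, leaving the inner results symbolic so they fold into the parent graph. A hedged sketch of what that enables, with illustrative names; only the optimized (codegen) runtime is shown in this diff.

    local autograd = require 'autograd'
    autograd.optimize(true)                   -- route through the codegen runtime

    local f = function(x)
       return torch.sum(torch.cmul(x, x))     -- f(x) = sum(x^2)
    end

    local df = autograd.grad(f)               -- df(x) ~ 2x

    local g = function(x)
       local dx = df(x)                       -- inner grad call, folded into the outer graph
       return torch.sum(dx)                   -- scalar, so the outer gradient is defined
    end

    local ddf = autograd.grad(g)              -- d/dx sum(df(x)) ~ a tensor of 2s
    print(ddf(torch.randn(5)))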
9 changes: 8 additions & 1 deletion src/util.lua
@@ -189,7 +189,14 @@ function util.sortedFlatten(tbl, flat)
for k, v in pairs(tbl) do
keys[#keys + 1] = k
end
table.sort(keys)
local ok = pcall(function()
return table.sort(keys)
end)
if not ok then
table.sort(keys, function(a, b)
return tostring(a) < tostring(b)
end)
end
for i = 1, #keys do
local val = tbl[keys[i]]
if type(val) == "table" then
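sortedFlatten previously assumed a table's keys were mutually comparable; with mixed numeric and string keys, which nested parameter tables can produce, table.sort raises an error, so it now falls back to ordering keys by their tostring form. A standalone illustration of the failure and the fallback (not commit code):

    local keys = { 2, "bias", 1, "weight" }

    local ok = pcall(table.sort, keys)
    print(ok)                                 -- false: attempt to compare number with string

    table.sort(keys, function(a, b)
       return tostring(a) < tostring(b)       -- "1" < "2" < "bias" < "weight"
    end)
    print(table.concat(keys, ", "))           -- 1, 2, bias, weight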
2 changes: 2 additions & 0 deletions test/test.lua
@@ -6,6 +6,8 @@ local gradcheckConstant = require 'autograd.gradcheck' {randomizeInput = false}
local tester = totem.Tester()
local stringx = require 'pl.stringx'

autograd.protected(true)

-- List of tests:
local tests = {
AutoModule = function()
