
Add autograd

apaszke committed Aug 19, 2016
1 parent 78a958a commit 53f00ae429aa1bd18b407ffd17d06c9e85578edf
@@ -0,0 +1,3 @@
from .variable import Variable
from .function import Function
@@ -0,0 +1,80 @@
from collections import Counter


class ExecutionEngine(object):
    def __init__(self):
        pass

    def _compute_dependencies(self, function):
        dependencies = {}
        seen = {function}
        queue = [function]
        while len(queue) > 0:
            fn = queue.pop()
            for prev_fn, arg_id in fn.previous_functions:
                if prev_fn not in dependencies:
                    dependencies[prev_fn] = [Counter() for _ in prev_fn.output_ids]
                output_idx = prev_fn.output_ids[arg_id]
                dependencies[prev_fn][output_idx][fn] += 1
                if prev_fn not in seen:
                    queue.append(prev_fn)
                    seen.add(prev_fn)
        return dependencies

    def _free_backward_dependency(self, dependencies, prev_fn, fn, arg_id):
        deps = dependencies[prev_fn]
        output_idx = prev_fn.output_ids[arg_id]
        output_deps = deps[output_idx]
        output_deps[fn] -= 1
        if output_deps[fn] == 0:
            del output_deps[fn]
        return output_idx

    def _is_ready_for_backward(self, dependencies, function):
        for deps in dependencies[function]:
            if len(deps) > 0:
                return False
        return True

    def run_backward(self, variable, grad):
        ready = [(variable.creator, (grad,))]
        not_ready = {}

        dependencies = self._compute_dependencies(variable.creator)

        while len(ready) > 0:
            fn, grad = ready.pop()
            # TODO: double-buffering
            grad_input = fn._do_backward(*grad)
            for (prev_fn, arg_id), d_prev_fn in zip(fn.previous_functions, grad_input):
                if not prev_fn.requires_grad:
                    assert d_prev_fn is None
                    continue

                output_nr = self._free_backward_dependency(dependencies, prev_fn, fn, arg_id)
                is_ready = self._is_ready_for_backward(dependencies, prev_fn)
                if is_ready:
                    if prev_fn in not_ready:
                        prev_grad = not_ready[prev_fn]
                        if not prev_grad[output_nr]:
                            prev_grad[output_nr] = d_prev_fn
                        else:
                            prev_grad[output_nr].add_(d_prev_fn)
                        del not_ready[prev_fn]
                    else:
                        assert output_nr == 0
                        prev_grad = (d_prev_fn,)
                    ready.append((prev_fn, prev_grad))
                else:
                    if prev_fn in not_ready:
                        prev_grad = not_ready[prev_fn]
                    else:
                        prev_grad = [None for _ in prev_fn.output_ids]
                    if not prev_grad[output_nr]:
                        prev_grad[output_nr] = d_prev_fn
                    else:
                        prev_grad[output_nr].add_(d_prev_fn)
                    not_ready[prev_fn] = prev_grad
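
The engine above counts, per function output, how many downstream consumers still owe a gradient, and only queues a function for its backward step once every consumer has reported in. A minimal standalone sketch of that dependency-counting traversal, using a toy graph of plain Python objects (ToyNode and the helper names are illustrative only, not part of this commit):

    # Standalone sketch: dependency counting on a toy graph. A node's gradient
    # is only propagated once every consumer has delivered its contribution.
    from collections import Counter

    class ToyNode(object):
        def __init__(self, name, inputs=()):
            self.name = name
            self.inputs = inputs  # nodes this one consumes

    def count_consumers(output):
        # How many times each node is consumed downstream of `output`.
        deps = Counter()
        queue, seen = [output], {output}
        while queue:
            node = queue.pop()
            for prev in node.inputs:
                deps[prev] += 1
                if prev not in seen:
                    seen.add(prev)
                    queue.append(prev)
        return deps

    def toposort_backward(output):
        # Emit nodes so that a node appears only after all of its consumers.
        deps = count_consumers(output)
        order, ready = [], [output]
        while ready:
            node = ready.pop()
            order.append(node)
            for prev in node.inputs:
                deps[prev] -= 1
                if deps[prev] == 0:
                    ready.append(prev)
        return order

    # x feeds both a and b, which feed y; x must come after both a and b.
    x = ToyNode('x')
    a = ToyNode('a', (x,))
    b = ToyNode('b', (x,))
    y = ToyNode('y', (a, b))
    print([n.name for n in toposort_backward(y)])  # e.g. ['y', 'b', 'a', 'x']
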
@@ -0,0 +1,40 @@
from collections import Counter
from .variable import Variable


class Function(object):
    def __init__(self):
        self.previous_functions = None
        self.output_ids = None
        self.needs_input_grad = None

    def __call__(self, *input):
        return self._do_forward(*input)

    def _do_forward(self, *input):
        unpacked_input = tuple(arg.data for arg in input)
        raw_output = self.forward(*unpacked_input)
        if not isinstance(raw_output, tuple):
            raw_output = (raw_output,)
        self.needs_input_grad = tuple(arg.creator.requires_grad for arg in input)
        self.requires_grad = any(self.needs_input_grad)
        output = tuple(Variable(tensor, self) for tensor in raw_output)
        self.previous_functions = [(arg.creator, id(arg)) for arg in input]
        self.output_ids = {id(var): i for i, var in enumerate(output)}
        return output

    def _do_backward(self, grad_output, buffers=None):
        grad_input = self.backward(grad_output)
        if not isinstance(grad_input, tuple):
            grad_input = (grad_input,)
        assert len(grad_input) == len(self.previous_functions), \
            self.__class__.__name__ + ' returned an invalid number of gradient tensors'
        return grad_input

    def forward(self, *input):
        raise NotImplementedError

    def backward(self, *grad_output):
        raise NotImplementedError
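
Subclasses only implement forward and backward: _do_forward unwraps Variables into raw tensors, records the graph edges, and wraps the results, while backward must return one gradient per input. A minimal sketch of that contract with plain Python floats standing in for tensors (Square is a hypothetical example, not part of this commit):

    # Hypothetical example of the forward/backward contract, floats in place of tensors.
    class Square(object):
        def forward(self, x):
            self.x = x                           # stash what backward will need
            return x * x

        def backward(self, grad_output):
            return 2.0 * self.x * grad_output    # d(x*x)/dx = 2x, times the upstream grad

    fn = Square()
    assert fn.forward(3.0) == 9.0
    assert fn.backward(1.0) == 6.0
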
@@ -0,0 +1,33 @@
import torch
from ..base.Variable import Variable
from ..base.Node import Node


class TorchNode(Node):
    def __init__(self, input_variable, *args):
        self.input_variable = input_variable
        # TODO: check if Tensors aren't mixed with Variables
        self.args = args
        self.unpacked_args = tuple(self._unpack_arg(arg) for arg in args)

    def _forward(self):
        result = getattr(self.input_variable.data, self.fn_name)(*self.unpacked_args)
        # If a function returns a number, we have to wrap it again
        if not torch.isTensor(result):
            result = self.input_variable.new((result,))
        return Variable(result, self)

    def _backward(self, grad_output, *args, **kwargs):
        grad_input = self.backward(grad_output)
        if not isinstance(grad_input, tuple):
            grad_input = (grad_input,)
        variables = (self.input_variable,) + tuple(filter(lambda x: isinstance(x, Variable), self.args))
        assert isinstance(grad_input, tuple)
        assert len(variables) == len(grad_input)
        for var, d_var in zip(variables, grad_input):
            var.backward(d_var, *args, **kwargs)

    def _unpack_arg(self, arg):
        if isinstance(arg, Variable):
            return arg.data
        return arg
@@ -0,0 +1,3 @@
from .basic_ops import *
from .tensor import *
from .pointwise import *
@@ -0,0 +1,29 @@
from ..TorchNode import TorchNode
from ...base.Variable import Variable


class Sum(TorchNode):
    fn_name = 'sum'

    def backward(self, grad_output):
        i = self.input_variable.data
        if len(self.args) == 0:
            # Sum over all elements: every input element receives the scalar grad.
            return i.new(i.size()).fill_(grad_output[0]),
        elif len(self.args) == 1:
            # Sum along a dimension: repeat the grad along that dimension.
            dim = self.args[0]
            repeats = [1 for _ in range(i.dim())]
            repeats[dim] = i.size(dim)
            return grad_output.repeatTensor(*repeats),


class Mean(TorchNode):
    fn_name = 'mean'

    def backward(self, grad_output):
        i = self.input_variable.data
        if len(self.args) == 0:
            # Mean over all elements: scalar grad divided by the element count.
            return i.new(i.size()).fill_(float(grad_output[0]) / i.numel()),
        elif len(self.args) == 1:
            # Mean along a dimension: repeat the grad, divide by that dimension's size.
            dim = self.args[0]
            repeats = [1 for _ in range(i.dim())]
            repeats[dim] = i.size(dim)
            return grad_output.repeatTensor(*repeats).div_(i.size(dim)),
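
In the per-dimension case, both backwards expand the upstream gradient back to the input's shape by repeating it along the reduced dimension (and, for Mean, dividing by that dimension's size). A standalone illustration with nested Python lists in place of tensors:

    # Standalone illustration, not the classes above: summing over dim 0 means
    # every row contributed once, so each row receives the same upstream gradient.
    x = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0]]
    y = [sum(col) for col in zip(*x)]      # forward: sum over dim 0 -> [5.0, 7.0, 9.0]
    grad_y = [0.1, 0.2, 0.3]
    grad_x = [list(grad_y) for _ in x]     # backward: repeat grad_y along dim 0
    assert y == [5.0, 7.0, 9.0]
    assert grad_x == [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]
    # For a mean over dim 0, each entry would additionally be divided by len(x).
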
@@ -0,0 +1,145 @@
import math

import torch
from ..variable import Variable
from ..function import Function


class Add(Function):
    def forward(self, a, b):
        return a.add(b)

    def backward(self, grad_output):
        return grad_output, grad_output


class Sub(Function):
    def forward(self, a, b):
        return a.sub(b)

    def backward(self, grad_output):
        return grad_output, grad_output.neg()


class Mul(Function):
    def forward(self, a, b):
        self.input = (a, b)
        return a.mul(b)

    def backward(self, grad_output):
        return grad_output.mul(self.input[1]), grad_output.mul(self.input[0])


class Div(Function):
    def forward(self, a, b):
        self.input = (a, b)
        return a.div(b)

    def backward(self, grad_output):
        a, b = self.input
        return grad_output.div(b), grad_output.neg().mul(a).div_(b).div_(b)


class Pow(Function):
    def forward(self, a, b):
        self.input = (a, b)
        return a.pow(b)

    def backward(self, grad_output):
        a, b = self.input
        return grad_output.mul(b).mul_(a.pow(b - 1)), grad_output.mul(a.pow(b)).mul_(a.log())


class AddConstant(Function):
    def __init__(self, constant):
        self.constant = constant

    def forward(self, a):
        return a.add(self.constant)

    def backward(self, grad_output):
        return grad_output


class SubConstant(Function):
    def __init__(self, constant, sub_tensor=False):
        self.constant = constant
        self.sub_tensor = sub_tensor

    def forward(self, a):
        if self.sub_tensor:
            return a.new().resizeAs_(a).fill_(self.constant).sub_(a)
        else:
            return a.sub(self.constant)

    def backward(self, grad_output):
        if self.sub_tensor:
            return grad_output.neg()
        else:
            return grad_output


class MulConstant(Function):
    def __init__(self, constant):
        self.constant = constant

    def forward(self, a):
        return a.mul(self.constant)

    def backward(self, grad_output):
        return grad_output.mul(self.constant)


class DivConstant(Function):
    def __init__(self, constant, div_by_tensor=False):
        self.constant = constant
        self.div_by_tensor = div_by_tensor

    def forward(self, a):
        if self.div_by_tensor:
            self.input = a
            return a.new().resizeAs_(a).fill_(self.constant).div_(a)
        else:
            return a.div(self.constant)

    def backward(self, grad_output):
        if self.div_by_tensor:
            a = self.input
            return grad_output.neg().mul_(self.constant).div_(a).div_(a)
        else:
            return grad_output.div(self.constant)


class PowConstant(Function):
    def __init__(self, constant, tensor_power=False):
        self.constant = constant
        self.tensor_power = tensor_power

    def forward(self, a):
        if self.tensor_power:
            self.fw_result = torch.pow(self.constant, a)
            return self.fw_result
        else:
            self.input = a
            return a.pow(self.constant)

    def backward(self, grad_output):
        if self.tensor_power:
            return grad_output.mul(self.fw_result).mul_(math.log(self.constant))
        else:
            a = self.input
            return grad_output.mul(self.constant).mul_(a.pow(self.constant - 1))


class Negate(Function):
    def forward(self, i):
        return i.neg()

    def backward(self, grad_output):
        return grad_output.neg()
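
The backward methods above are straight applications of the chain rule, e.g. d(a/b)/db = -a/b^2 and d(a^b)/db = a^b * log(a). A small standalone check with plain floats and central finite differences (fd is a throwaway helper, not part of the commit):

    import math

    def fd(f, x, eps=1e-6):
        # Central finite-difference approximation of df/dx.
        return (f(x + eps) - f(x - eps)) / (2 * eps)

    a, b = 2.0, 3.0
    # Div: d(a/b)/da = 1/b, d(a/b)/db = -a/b^2
    assert abs(fd(lambda t: t / b, a) - 1.0 / b) < 1e-4
    assert abs(fd(lambda t: a / t, b) + a / b ** 2) < 1e-4
    # Pow: d(a**b)/da = b * a**(b-1), d(a**b)/db = a**b * log(a)
    assert abs(fd(lambda t: t ** b, a) - b * a ** (b - 1)) < 1e-4
    assert abs(fd(lambda t: a ** t, b) - a ** b * math.log(a)) < 1e-4
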
@@ -0,0 +1,30 @@
from ..variable import Variable
from ..function import Function


class Exp(Function):
    def forward(self, i):
        self.result = i.exp()
        return self.result

    def backward(self, grad_output):
        return self.result * grad_output


class Log(Function):
    def forward(self, i):
        self.input = i
        return i.log()

    def backward(self, grad_output):
        return grad_output.div(self.input)


class Log1p(Function):
    def forward(self, i):
        self.input = i
        return i.log1p()

    def backward(self, grad_output):
        return grad_output.div(self.input.add(1))