Skip to content
This repository has been archived by the owner on Nov 1, 2021. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
add gather and scatter (#162)
  • Loading branch information
szagoruyko authored and alexbw committed Oct 23, 2016
1 parent e93e2a1 commit be47e50
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/gradfuns.lua
Expand Up @@ -549,6 +549,16 @@ overload.module("torch", torch, function(module)
-- function(g, ans, x, size, dim) return nil end,

-- })

module.gradient("gather", {
function(g, ans, x, dim, index) return torch.scatter(util.zerosLike(x), dim, index, g) end,
})
module.gradient("scatter", {
function(g, ans, x, dim, index, val) return nil end,
function(g, ans, x, dim, index, val) return nil end,
function(g, ans, x, dim, index, val) return nil end,
function(g, ans, x, dim, index, val) return torch.gather(g, dim, index) end,
})

module.gradient("bmm", {
function(g, ans, x, y) return torch.bmm(g, torch.transpose(y, 3, 2)) end,
Expand Down
21 changes: 21 additions & 0 deletions test/test.lua
Expand Up @@ -1908,6 +1908,27 @@ local tests = {
tester:assert(gradcheck(bmmFn, {X=X, Y=Y}), "Incorrect gradient")
end,

Gather = function()
local X = torch.randn(5,5)
local index = torch.LongTensor{{1, 2, 3, 4, 5}, {2, 3, 4, 5, 1}}

local gather = function(inputs, index)
return torch.sum(torch.gather(inputs.X, 1, index))
end
tester:assert(gradcheck(gather, {X = X}, index), "Incorrect gradient")
end,

Scatter = function()
local X = torch.rand(2, 5)
local index = torch.LongTensor{{1, 2, 3, 1, 1}, {3, 1, 1, 2, 3}}
local Z = torch.zeros(3, 5)

local scatter = function(inputs, index, Z)
return torch.sum(torch.scatter(Z, 1, index, inputs.X))
end
tester:assert(gradcheck(scatter, {X = X}, index, Z), "Incorrect gradient")
end,

Baddbmm = function()
local v1 = torch.randn(1)[1]
local v2 = torch.randn(1)[1]
Expand Down

0 comments on commit be47e50

Please sign in to comment.