Skip to content

Commit

Permalink
Merge 82c6525 into a46a96c
Browse files Browse the repository at this point in the history
  • Loading branch information
sjrodahl committed Mar 1, 2020
2 parents a46a96c + 82c6525 commit 9a4370f
Show file tree
Hide file tree
Showing 10 changed files with 228 additions and 209 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ version = "0.1.0"
[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
julia = "1"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"

[targets]
test = ["Test"]
1 change: 1 addition & 0 deletions src/DNC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module DNC

using Flux
using LinearAlgebra
using Parameters
using Zygote

export contentaddress
Expand Down
70 changes: 35 additions & 35 deletions src/access.jl
Original file line number Diff line number Diff line change
@@ -1,75 +1,75 @@
using Zygote: @adjoint


mutable struct State{A<:AbstractArray, V<:AbstractArray}
@with_kw mutable struct State{A<:AbstractArray, V<:AbstractArray}
L::Matrix
p::V
u::V
ww::V
wr::A
w_w::V
w_r::A
end

State(N::Int, R::Int) = State(
zeros(N, N),
zeros(N),
zeros(N),
zeros(N),
[zeros(N) for i in 1:R]
L=zeros(N, N),
p = zeros(N),
u = zeros(N),
w_w = zeros(N),
w_r = [zeros(N) for i in 1:R]
)

struct WriteHead{A<:AbstractArray, T<:Real}
@with_kw struct WriteHead{A<:AbstractArray, T<:Real}
k::A # Write key
β::T # Key strength
e::A # erase
v::A # add
ga::T # allocation gate
gw::T # write gate
g_a::T # allocation gate
g_w::T # write gate
end

struct ReadHead{A<:AbstractArray, T<:Real}
@with_kw struct ReadHead{A<:AbstractArray, T<:Real}
k::A # read key
β::T # key strength
f::T # free gate
π::A # read mode
end

# L should be updated before this
function readmem(M, rh::ReadHead, L::Matrix, prev_wr)
function readmem(M, rh::ReadHead, L::Matrix, prev_w_r)
@unpack k, β, π = rh
cr = contentaddress(k, M, β)
b = backwardweight(L, prev_wr)
f = forwardweight(L, prev_wr)
wr = readweight(b, cr, f, π)
r = M' * wr
c_r = contentaddress(k, M, β)
b = backwardweight(L, prev_w_r)
f = forwardweight(L, prev_w_r)
w_r = readweight(b, c_r, f, π)
r = M' * w_r
r
end

function writemem(M,
wh::WriteHead,
free::AbstractArray,
prev_ww::AbstractArray,
prev_wr::AbstractArray,
prev_w_w::AbstractArray,
prev_w_r::AbstractArray,
prev_usage::AbstractArray)
k, β, ga, gw, e, v = wh.k, wh.β, wh.ga, wh.gw, wh.e, wh.v
cw = contentaddress(k, M, β)
𝜓 = memoryretention(prev_wr, free)
u = usage(prev_usage, prev_ww, 𝜓)
@unpack k, β, g_a, g_w, e, v = wh
c_w = contentaddress(k, M, β)
𝜓 = memoryretention(prev_w_r, free)
u = usage(prev_usage, prev_w_w, 𝜓)
a = allocationweighting(u)
ww = writeweight(cw, a, gw, ga)
newmem = eraseandadd(M, ww, e, v)
w_w = writeweight(c_w, a, g_w, g_a)
newmem = erase_and_add(M, w_w, e, v)
newmem
end

function update_state_after_write!(state::State, M, wh::WriteHead, free::AbstractArray)
cw = contentaddress(wh.k, M, wh.β)
𝜓 = memoryretention(state.wr, free)
u = usage(state.u, state.ww, 𝜓)
c_w = contentaddress(wh.k, M, wh.β)
𝜓 = memoryretention(state.w_r, free)
u = usage(state.u, state.w_w, 𝜓)
a = allocationweighting(u)
ww = writeweight(cw, a, wh.gw, wh.ga)
w_w = writeweight(c_w, a, wh.g_w, wh.g_a)
state.u = u
state.ww = ww
updatelinkmatrix!(state.L, state.p, state.ww)
state.p = precedenceweight(state.p, state.ww)
state.w_w = w_w
updatelinkmatrix!(state.L, state.p, state.w_w)
state.p = precedenceweight(state.p, state.w_w)
end

@adjoint update_state_after_write!(state::State, M, wh::WriteHead, free::AbstractArray) =
Expand All @@ -83,11 +83,11 @@ function update_state_after_read!(state::State, M, rhs::AbstractArray)
wr = readweight(b, cr, f, rh.π)
wr
end
state.wr = [new_wr(state.L, state.wr[i], M, rhs[i]) for i in 1:length(rhs)]
state.w_r = [new_wr(state.L, state.w_r[i], M, rhs[i]) for i in 1:length(rhs)]
state
end

@adjoint update_state_after_read!(state::State, M, rhs::AbstractArray) =
update_state_after_read!(state, M, rhs), _ -> nothing

eraseandadd(M, ww, e, a) = M .* (ones(size(M)) - ww * e') + ww * a'
erase_and_add(M, w_w, e, a) = M .* (ones(size(M)) - w_w * e') + w_w * a'
52 changes: 26 additions & 26 deletions src/addressing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,50 +10,50 @@ Compute the similarity K (default cosine similarity) between all rows of memory
function contentaddress(key, M, β, K=cosinesim)
r, c = size(M)
xs = [K(key, M[row,:]) for row in 1:r]
weightedsoftmax(xs, β)
weighted_softmax(xs, β)
end

# Single read head
function memoryretention(readweights::AbstractArray{<:Number, 1}, freegate)
return ones(length(readweights)) .- freegate.*readweights
function memoryretention(read_weights::AbstractArray{<:Number, 1}, free_gate)
return ones(length(read_weights)) .- free_gate.*read_weights
end

# Multiple read heads
function memoryretention(readweights::Array{<:AbstractArray, 1}, freegate)
R = length(readweights)
rs = [ones(length(readweights[i])) .- freegate[i].*readweights[i] for i in 1:R]
function memoryretention(read_weights::Array{<:AbstractArray, 1}, free_gate)
R = length(read_weights)
rs = [ones(length(read_weights[i])) .- free_gate[i].*read_weights[i] for i in 1:R]
foldl(rs) do x, y
x.*y
end
end

usage(u_prev, writeweights, 𝜓) = (u_prev + writeweights - (u_prev.*writeweights)) .* 𝜓
usage(u_prev, write_weights, 𝜓) = (u_prev + write_weights - (u_prev.*write_weights)) .* 𝜓

const _EPSILON = 1e-6


cumprodexclusive(arr::AbstractArray) = cumprod(arr) ./ arr
cumprod_exclusive(arr::AbstractArray) = cumprod(arr) ./ arr

function allocationweighting(u::AbstractArray; eps::AbstractFloat=_EPSILON)
u = eps .+ (1 - eps) .* u # Ensure values are large enough for numerical stability in cumprodexclusive
u = eps .+ (1 - eps) .* u # Ensure values are large enough for numerical stability in cumprod_exclusive
N = length(u)
ϕ = sortperm(u)
sortedusage = u[ϕ]
prodsortedusage = cumprodexclusive(sortedusage)
sortedalloc = (1 .- sortedusage) .* prodsortedusage
prod_sortedusage = cumprod_exclusive(sortedusage)
sortedalloc = (1 .- sortedusage) .* prod_sortedusage
a = sortedalloc[ϕ]
a
end

function allocationweighting(freegate, prev_wr, prev_ww, prev_usage; eps::AbstractFloat=_EPSILON)
𝜓 = memoryretention(prev_wr, freegate)
u = usage(prev_usage, prev_ww, 𝜓)
function allocationweighting(free_gate, prev_w_r, prev_w_w, prev_usage; eps::AbstractFloat=_EPSILON)
𝜓 = memoryretention(prev_w_r, free_gate)
u = usage(prev_usage, prev_w_w, 𝜓)
allocationweighting(u)
end

function allocationweighting(freegate, state::State; eps::AbstractFloat=_EPSILON)
wr, ww, u = state.wr, state.ww, state.u
allocationweighting(freegate, wr, ww, u)
function allocationweighting(free_gate, state::State; eps::AbstractFloat=_EPSILON)
@unpack w_r, w_w, u = state
allocationweighting(free_gate, w_r, w_w, u)
end

using Zygote: @adjoint
Expand All @@ -62,33 +62,33 @@ using Zygote: @adjoint
@adjoint allocationweighting(u::AbstractArray; eps=_EPSILON) =
allocationweighting(u; eps=eps), Δ -> (Δ, Δ)

function writeweight(cw, a, gw, ga)
return gw*(ga.*(a) + (1-ga)cw)
function writeweight(c_w, a, g_w, g_a)
return g_w*(g_a.*(a) + (1-g_a)c_w)
end

precedenceweight(p_prev, ww) = (1-sum(ww))*p_prev + ww
precedenceweight(p_prev, w_w) = (1-sum(w_w))*p_prev + w_w

function updatelinkmatrix!(L, precedence, ww)
function updatelinkmatrix!(L, precedence, w_w)
N, _ = size(L)
for i in 1:N
for j in 1:N
if i != j
L[i, j] = (1 - ww[i] - ww[j]) * L[i, j] + ww[i]*precedence[j]
L[i, j] = (1 - w_w[i] - w_w[j]) * L[i, j] + w_w[i]*precedence[j]
end
end
end
L
end

forwardweight(L, wr) = L*wr
backwardweight(L, wr) = L'*wr
forwardweight(L, w_r) = L*w_r
backwardweight(L, w_r) = L'*w_r

"""
readweight(backw, content, forw, read_mode)
Interpolate the backward weighting, content weighting and forward weighting.
read_mode is a vector of size 3 summing to 1.
"""
function readweight(backw, content, forw, readmode)
return readmode[1]*backw + readmode[2]*content + readmode[3]*forw
function readweight(backw, content, forw, read_mode)
return read_mode[1]*backw + read_mode[2]*content + read_mode[3]*forw
end
8 changes: 4 additions & 4 deletions src/computer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ DNCCell(in::Int, out::Int, N::Int, W::Int, R::Int; init=Flux.glorot_uniform) =
)

function (m::DNCCell)(h, x)
L, ww, wr, u = m.state.L, m.state.ww, m.state.wr, m.state.u
@unpack L, w_w, w_r, u = m.state
numreads = m.R
out = m.controller([x;h])
v = out[1:m.Y]
ξ = out[m.Y+1:length(out)]
rhs, wh = splitparams(ξ, numreads, m.W)
rhs, wh = split_ξ(ξ, numreads, m.W)
freegate = [rh.f for rh in rhs]
m.M = writemem(m.M, wh, freegate, ww, wr, u)
m.M = writemem(m.M, wh, freegate, w_w, w_r, u)
update_state_after_write!(m.state, m.M, wh, freegate)
r = [readmem(m.M, rh, L, wr[1]) for rh in rhs]
r = [readmem(m.M, rh, L, w_r[1]) for rh in rhs]
r = vcat(r...)
update_state_after_read!(m.state, m.M, rhs)
m.readvectors = r # Flatten list of lists
Expand Down
64 changes: 43 additions & 21 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,67 @@ using LinearAlgebra

cosinesim(u, v) = dot(u, v)/(norm(u)*norm(v))

weightedsoftmax(xs, weight) = softmax(xs.*weight)
weighted_softmax(xs, weight) = softmax(xs.*weight)

oneplus(x) = 1 + log(1+exp(x))

inputsize(X::Int, R::Int, W::Int) = X + R * W

outputsize(R::Int, N::Int, W::Int, X::Int, Y::Int) = W*R + 3W + 5R +3 + Y

function calcoutput(v, r, Wr)
return v .+ Wr*r
function calcoutput(v, r, W_r)
return v .+ W_r*r
end

function splitparams(ξ, R::Int, W::Int)
#"""
#Assuming R=1. Otherwise, the code will break
#"""
#function split_ξ(ξ, W::Int)
# R = 1
# length(ξ) != (W*R)+3W+5R+3 &&
# error("Length of xi-vector is incorrect. Expected $((W*R)+3W+5R+3), got $(length(ξ))")
# kr = ξ[1:W]
# βr = ξ[W+1]
# kw = ξ[(W+2):(2W+1)]
# βw = ξ[2W+2]
# ê = ξ[(2W+3):(3W+2)]
# v = ξ[(3W+3):(4W+2)]
# f̂ = ξ[4W+3]
# ĝa = ξ[4W+4]
# ĝw = ξ[4W+5]
# readmode = ξ[(4W+6):length(ξ)]
# rh = ReadHead(kr, βr, σ(f̂), softmax(readmode))
# wh = WriteHead(kw, βw, σ.(ê), v, σ(ĝa), σ(ĝw))
# return (rh, wh)
#end
#

function split_ξ(ξ, R::Int, W::Int)
length(ξ) != (W*R)+3W+5R+3 &&
error("Length of xi-vector is incorrect. Expected $((W*R)+3W+5R+3), got $(length(ξ))")
# read keys
kr = [ξ[((r-1)*W+1):r*W] for r in 1:R]
βr = ξ[(R*W+1):(R*W+R)]
kw = ξ[(R*W+1+R):(R*W+R+W)]
βw = ξ[(R*W+R+W+1)]
k_r = [ξ[((r-1)*W+1):r*W] for r in 1:R]
β_r = ξ[(R*W+1):(R*W+R)]
k_w = ξ[(R*W+1+R):(R*W+R+W)]
β_w = ξ[(R*W+R+W+1)]
= ξ[(R*W+R+W+2):(R*W+R+2W+1)]
v = ξ[(R*W+R+2W+2):(R*W+R+3W+1)]
= ξ[(R*W+R+3W+2):(R*W+2R+3W+1)]
ĝa = ξ[(R*W+2R+3W+2)]
ĝw = ξ[(R*W+2R+3W+3)]
ĝ_a = ξ[(R*W+2R+3W+2)]
ĝ_w = ξ[(R*W+2R+3W+3)]
rest = ξ[(R*W+2R+3W+4):length(ξ)]
readmode = [rest[((r-1)*3+1):3r] for r in 1:R]
rhs = [ReadHead(
kr[i],
βr[i],
σ(f̂[i]),
Flux.softmax(readmode[i])) for i in 1:R]
k=k_r[i],
β=β_r[i],
f=σ(f̂[i]),
π=Flux.softmax(readmode[i])) for i in 1:R]
wh = WriteHead(
kw,
βw,
σ.(ê),
v,
σ(ĝa),
σ(ĝw)
k=k_w,
β=β_w,
e=σ.(ê),
v=v,
g_a = σ(ĝ_a),
g_w = σ(ĝ_w)
)
return (rhs, wh)
end
Loading

0 comments on commit 9a4370f

Please sign in to comment.