In [1]:
using Flux,LinearAlgebra

In [2]:
"""
    clone(m, N)

Return a vector containing `N` independent `deepcopy`s of `m`.
"""
function clone(m::Any, N::Int)
    return map(_ -> deepcopy(m), 1:N)
end

"""
    MultiHeadedAttention(d_model, h, dropout)

State of a multi-head attention layer.

# Fields
- `d_model`: model width.
- `h`: number of heads.
- `d_k`: per-head width, `d_model / h`.
- `linears`: four cloned `Dense(d_model, d_model)` layers.
- `dropout`: `Dropout` layer built from the `dropout` probability.
- `attn`: starts as `nothing`; overwritten by the forward pass.
"""
mutable struct MultiHeadedAttention
    d_model::Int
    h::Int
    d_k::Int
    linears::Array
    dropout::Dropout
    attn::Union{Array,Nothing}

    function MultiHeadedAttention(d_model::Int, h::Int, dropout::Float64)
        # `Int(...)` throws an InexactError unless `h` divides `d_model` exactly.
        per_head = Int(d_model / h)
        projections = clone(Dense(d_model, d_model), 4)
        return new(d_model, h, per_head, projections, Dropout(dropout), nothing)
    end
end
# Forward pass: project query/key/value with the first three linear layers,
# split them into `h` heads, run `attention`, merge the heads again and apply
# the last linear layer. The attention weights are stored in `mh.attn`.
function (mh::MultiHeadedAttention)(
    query::Array, key::Array, value::Array, mask::Union{Array,Nothing}
)
    # A singleton axis at position 2 lets the same mask be applied to all h heads.
    head_mask = mask === nothing ? nothing : unsqueeze(mask, 2)

    nbatches = size(query, 1)

    # 1) Linear projections, then d_model => h x d_k with axes 2 and 3 swapped.
    split_heads(t) = permutedims(reshape(t, nbatches, :, mh.h, mh.d_k), (1, 3, 2, 4))
    q, k, v = [
        split_heads(layer(input)) for
        (layer, input) in zip(mh.linears, (query, key, value))
    ]

    # 2) Attention over all projected vectors; keep the weights on the layer.
    context, mh.attn = attention(q, k, v, head_mask, mh.dropout)

    # 3) Undo the axis swap, merge heads ("concat") and apply the final linear.
    merged = reshape(permutedims(context, (1, 3, 2, 4)), nbatches, :, mh.h * mh.d_k)

    return mh.linears[end](merged)
end

"""
    PositionwiseFeedForward(d_model, d_ff, dropout)

Two-layer feed-forward block: `w_1` maps `d_model => d_ff`, `w_2` maps
`d_ff => d_model`, and `dropout` is a `Dropout` layer with the given probability.
"""
mutable struct PositionwiseFeedForward
    w_1::Dense
    w_2::Dense
    dropout::Dropout

    function PositionwiseFeedForward(d_model::Int, d_ff::Int, dropout::Float64)
        expand = Dense(d_model, d_ff)
        contract = Dense(d_ff, d_model)
        return new(expand, contract, Dropout(dropout))
    end
end

# Forward pass: w_2(dropout(relu.(w_1(x)))).
function (pl::PositionwiseFeedForward)(x::Array)
    # `relu` is a scalar activation, so it has to be broadcast over the array
    # produced by `w_1`; calling it on the array directly raises an error.
    hidden = relu.(pl.w_1(x))
    return pl.w_2(pl.dropout(hidden))
end

"""
    Sublayer(size, dropout)

Holds the `LayerNorm(size)` and `Dropout(dropout)` layers used around a sub-layer,
together with the width `size`.
"""
mutable struct Sublayer
    size::Int
    norm::LayerNorm
    dropout::Dropout

    function Sublayer(size::Int, dropout::Float64)
        return new(size, LayerNorm(size), Dropout(dropout))
    end
end

# Residual connection around a sub-layer: x .+ dropout(f(norm(x))).
#
# `f` may be any callable taking one argument (a closure or a layer struct), so
# it is deliberately left untyped: callable structs are not `Function`s.
function (sb::Sublayer)(x::Matrix, f)
    # Use this layer's own fields. An unqualified `norm` would resolve to
    # `LinearAlgebra.norm` (imported at the top of the file) and return a scalar.
    normed = sb.norm(x)
    return x .+ sb.dropout(f(normed))
end

"""
    EncoderLayer(size, self_attn, feed_forward, dropout)

One encoder layer: a self-attention block, a feed-forward block, and two cloned
`Sublayer(size, dropout)` wrappers stored in `sublayer`.
"""
mutable struct EncoderLayer
    size::Int
    self_attn::MultiHeadedAttention
    feed_forward::PositionwiseFeedForward
    sublayer::Array

    function EncoderLayer(
        size::Int,
        self_attn::MultiHeadedAttention,
        feed_forward::PositionwiseFeedForward,
        dropout::Float64,
    )
        wrappers = clone(Sublayer(size, dropout), 2)
        return new(size, self_attn, feed_forward, wrappers)
    end
end

# Forward pass: self-attention sub-layer followed by the feed-forward sub-layer.
function (el::EncoderLayer)(x::Matrix, src_mask::Matrix)
    # The sub-layer wrapper calls its function with a single argument, so the
    # mask is captured by the closure instead of being a second parameter.
    x = el.sublayer[1](x, h -> el.self_attn(h, h, h, src_mask))
    # Wrapped in a closure so the argument is a plain `Function`.
    return el.sublayer[2](x, h -> el.feed_forward(h))
end

"""
    Encoder(layer, N)

Stack of `N` deep copies of `layer` plus a final `LayerNorm(layer.size)`.
"""
mutable struct Encoder
    layers::Array
    norm::LayerNorm

    function Encoder(layer::EncoderLayer, N::Int)
        stack = clone(layer, N)
        return new(stack, LayerNorm(layer.size))
    end
end

# Forward pass: feed `x` (with `src_mask`) through every layer in turn, then
# apply the encoder's final normalisation.
function (en::Encoder)(x::Matrix, src_mask::Matrix)
    for layer in en.layers
        x = layer(x, src_mask)
    end
    # The receiver is `en`; there is no `self` binding in Julia.
    return en.norm(x)
end

"""
    DecoderLayer(size, self_attn, src_attn, feed_forward, dropout)

One decoder layer: self-attention, source attention, a feed-forward block, and
three cloned `Sublayer(size, dropout)` wrappers stored in `sublayer`.
"""
mutable struct DecoderLayer
    size::Int
    self_attn::MultiHeadedAttention
    src_attn::MultiHeadedAttention
    feed_forward::PositionwiseFeedForward
    sublayer::Array

    function DecoderLayer(
        size::Int,
        self_attn::MultiHeadedAttention,
        src_attn::MultiHeadedAttention,
        feed_forward::PositionwiseFeedForward,
        dropout::Float64,
    )
        wrappers = clone(Sublayer(size, dropout), 3)
        return new(size, self_attn, src_attn, feed_forward, wrappers)
    end
end

# Forward pass: masked self-attention, attention over the encoder `memory`,
# then the feed-forward block, each wrapped by its own sub-layer.
function (dl::DecoderLayer)(x::Matrix, memory::Matrix, src_mask::Matrix, tgt_mask::Matrix)
    # The sub-layer wrapper calls its function with a single argument, so the
    # masks and `memory` are captured by the closures rather than passed in.
    x = dl.sublayer[1](x, h -> dl.self_attn(h, h, h, tgt_mask))
    x = dl.sublayer[2](x, h -> dl.src_attn(h, memory, memory, src_mask))
    # Wrapped in a closure so the argument is a plain `Function`.
    return dl.sublayer[3](x, h -> dl.feed_forward(h))
end


"""
    Decoder(layer, N)

Stack of `N` deep copies of `layer` plus a final `LayerNorm(layer.size)`.
"""
mutable struct Decoder
    layers::Array
    norm::LayerNorm

    function Decoder(layer::DecoderLayer, N::Int)
        stack = clone(layer, N)
        return new(stack, LayerNorm(layer.size))
    end
end

# Forward pass: feed `x` through every decoder layer (each also receives
# `memory` and both masks), then apply the decoder's final normalisation.
function (dc::Decoder)(x::Matrix, memory::Matrix, src_mask::Matrix, tgt_mask::Matrix)
    for layer in dc.layers
        x = layer(x, memory, src_mask, tgt_mask)
    end
    # The receiver is `dc`; there is no `self` binding in Julia.
    return dc.norm(x)
end


"""
    subsequent_mask(size)

Mask out subsequent positions.

Return a `size`×`size` `Bool` matrix whose entry `[i, j]` is `true` exactly when
`j <= i` (the lower triangle including the diagonal).
"""
function subsequent_mask(size::Int)
    # Build the Boolean lower triangle directly instead of comparing a
    # floating-point triangular matrix against 1.0.
    return Array(tril(trues(size, size)))
end

"""
    fill_mask!(a, mask, esp=-1e9)

Overwrite, in place, the entries of `a` selected by the elementwise complement
`.~mask` with `esp`, and return `a`.
"""
function fill_mask!(a::Array, mask::Array, esp=-1e9)
    masked_out = .~mask
    a[masked_out] .= esp
    return a
end

"""
    swap_last_2_dimesions!(a)

Return a copy of `a` with its last two dimensions exchanged (via `permutedims`).
Despite the trailing `!`, `a` itself is not modified.
"""
function swap_last_2_dimesions!(a::Array)
    perm = collect(1:ndims(a))
    # Exchange the final two axis indices; all leading axes keep their place.
    perm[end-1], perm[end] = perm[end], perm[end-1]
    return permutedims(a, perm)
end

"""
    SourceEmbedding(m)

Wrapper around the `Chain` `m` used to embed source inputs.
"""
mutable struct SourceEmbedding
    m::Chain
end

# Apply the wrapped chain. The chain is a field of `se`, not a global `m`.
function (se::SourceEmbedding)(x::Array)
    return se.m(x)
end

"""
    TargetEmbedding(m)

Wrapper around the `Chain` `m` used to embed target inputs.
"""
mutable struct TargetEmbedding
    m::Chain
end

# Apply the wrapped chain. The chain is a field of `se`, not a global `m`.
function (se::TargetEmbedding)(x::Array)
    return se.m(x)
end

"""
    Generator(d_model, vocab)

Output head: a `Dense(d_model, vocab)` projection followed by `logsoftmax`,
stored as the chain `m`.
"""
struct Generator
    m::Chain

    # Project to `vocab` outputs; the `vocab` argument was previously ignored and
    # the projection mapped `d_model => d_model`.
    Generator(d_model::Int, vocab::Int) = new(Chain(Dense(d_model, vocab), logsoftmax))
end

# Apply the wrapped chain. Calling `g(x)` here would re-enter this same method
# and never terminate.
function (g::Generator)(x::Array)
    return g.m(x)
end

"""
    EncoderDecoder(encoder, decoder, src_embed, tgt_embed, generator)

Container tying together the encoder, the decoder, both embedding wrappers and
the generator head.
"""
mutable struct EncoderDecoder
    encoder::Encoder
    decoder::Decoder
    src_embed::SourceEmbedding
    tgt_embed::TargetEmbedding
    generator::Generator

    function EncoderDecoder(
        encoder::Encoder,
        decoder::Decoder,
        src_embed::SourceEmbedding,
        tgt_embed::TargetEmbedding,
        generator::Generator,
    )
        return new(encoder, decoder, src_embed, tgt_embed, generator)
    end
end

# Forward pass: embed the target, encode the embedded source into `memory`,
# then decode. The generator head is not applied here.
function (ed::EncoderDecoder)(src::Matrix, tgt::Matrix, src_mask::Matrix, tgt_mask::Matrix)
    # The target embedding is evaluated before the encoder, as in the
    # single-expression form, so any random state is consumed in the same order.
    target = ed.tgt_embed(tgt)
    memory = ed.encoder(ed.src_embed(src), src_mask)
    return ed.decoder(target, memory, src_mask, tgt_mask)
end

In [3]:
# Multiply `a` and `b` over their last two dimensions, treating every leading
# dimension as a batch index. Inputs with at most two dimensions use plain `*`.
function _matmul_last2(a::AbstractArray, b::AbstractArray)
    if ndims(a) <= 2 && ndims(b) <= 2
        return a * b
    end
    lead = size(a)[1:end-2]
    if ndims(a) != ndims(b) || lead != size(b)[1:end-2]
        throw(DimensionMismatch("batch dimensions differ: $(size(a)) vs $(size(b))"))
    end
    rows = size(a, ndims(a) - 1)
    cols = size(b, ndims(b))
    out = similar(a, promote_type(eltype(a), eltype(b)), (lead..., rows, cols))
    # NOTE(review): writes into `out` in place; confirm this is acceptable for the
    # automatic-differentiation backend used for training.
    for idx in CartesianIndices(lead)
        out[idx, :, :] = a[idx, :, :] * b[idx, :, :]
    end
    return out
end

"""
    attention(query, key, value, mask, dropout)

Compute scaled dot-product attention, `softmax(Q * Kᵀ / sqrt(d_k)) * V`, where the
products and the softmax act on the last dimensions and `d_k` is the size of the
last dimension of `query`.

# Arguments
- `query`, `key`, `value`: arrays whose last two dimensions are multiplied; any
  leading dimensions are treated as batch indices and must match.
- `mask`: `nothing`, or an array broadcastable against the scores; positions where
  it equals zero/`false` get the score `-1e9` before the softmax.
- `dropout`: `nothing`, or a callable (e.g. a `Dropout` layer) applied to the
  attention weights.

# Returns
A tuple `(output, p_attn)` of the attended values and the attention weights.
"""
function attention(query, key, value, mask::Union{AbstractArray,Nothing}, dropout)
    d_k = size(query)[end]
    scores = _matmul_last2(query, swap_last_2_dimesions!(key)) ./ sqrt(d_k)
    if mask !== nothing
        # Broadcasting (rather than same-shape logical indexing) lets a mask with
        # singleton axes, such as the per-head axis added by the caller, apply.
        fill_value = convert(eltype(scores), -1e9)
        scores = ifelse.(mask .!= 0, scores, fill_value)
    end
    p_attn = softmax(scores, dims=ndims(scores))

    if dropout !== nothing
        p_attn = dropout(p_attn)
    end
    return _matmul_last2(p_attn, value), p_attn
end

attention (generic function with 1 method)

In [4]:
#unsqueeze
"""
    unsqueeze(xs, dim)

Reshape `xs` so that a new axis of length 1 sits at position `dim`.
"""
function unsqueeze(xs, dim)
    shape = size(xs)
    leading = shape[1:dim-1]
    trailing = shape[dim:end]
    return reshape(xs, (leading..., 1, trailing...))
end

unsqueeze (generic function with 1 method)

In [5]:
"""
    EmbeddingLayer(mf, vs)

Embedding table `W`, initialised with `Flux.glorot_normal(mf, vs)`.
"""
mutable struct EmbeddingLayer
    W

    function EmbeddingLayer(mf, vs)
        return new(Flux.glorot_normal(mf, vs))
    end
end

# Look up embeddings: flatten `x`, one-hot encode it against the 0-based labels
# `0:vocab-1`, and multiply by the table `W`.
# The flattened length and vocabulary size come from `x` and `W` themselves
# rather than from the globals `pad_size`, `N` and `vocab_size`; `vec(x)` equals
# `reshape(x, pad_size*N)` whenever that reshape is valid, and `size(m.W, 2)` is
# the only label count for which the product is defined.
(m::EmbeddingLayer)(x) = m.W * Flux.onehotbatch(vec(x), 0:size(m.W, 2)-1)

In [6]:
"""
    Embeddings(d_model, vocab)

Holds the lookup table `lut = EmbeddingLayer(d_model, vocab)` and the model
width `d_model`.
"""
mutable struct Embeddings
    lut::EmbeddingLayer
    d_model::Int

    function Embeddings(d_model::Int, vocab::Int)
        table = EmbeddingLayer(d_model, vocab)
        return new(table, d_model)
    end
end

# Embed `x` with the lookup table and scale the result by sqrt(d_model).
# The lookup goes through `emb.lut` (calling `emb(x)` would recurse forever) and
# the width is the `d_model` field of `emb`.
function (emb::Embeddings)(x::Array)
    return emb.lut(x) .* sqrt(emb.d_model)
end


In [7]:
"""
    init_pe(max_len, d_model)

Build the `max_len`×`d_model` sinusoidal positional-encoding table: row `p + 1`
holds position `p`, odd columns hold `sin` terms and even columns hold the
matching `cos` terms. `Int(d_model / 2)` throws for odd `d_model`.
"""
function init_pe(max_len::Int, d_model::Int)
    table = zeros(max_len, d_model)
    # Positions 0, 1, ..., max_len-1 as a column.
    positions = reshape(collect(0:max_len-1), (max_len, 1))
    half = Int(d_model / 2)
    decay = -(log(10000.0) / d_model)
    # One frequency per sin/cos column pair: exp(decay * 0), exp(decay * 2), ...
    freqs = exp.(collect(0:2:2*(half-1)) .* decay)
    angles = positions .* transpose(freqs)
    table[:, 2:2:2*half] = cos.(angles)
    table[:, 1:2:2*half-1] = sin.(angles)
    return table
end
"""
    PositionalEncoding(d_model, p_dropout, max_len)

Positional-encoding layer: stores the width, a `Dropout(p_dropout)` layer, the
maximum length and the table `pe = init_pe(max_len, d_model)`.
"""
struct PositionalEncoding
    d_model::Int
    dropout::Dropout
    max_len::Int
    pe::Array

    function PositionalEncoding(d_model::Int, p_dropout::Float64, max_len::Int)
        drop = Dropout(p_dropout)
        table = init_pe(max_len, d_model)
        return new(d_model, drop, max_len, table)
    end
end

# Add the positional table to `x` and apply this layer's dropout.
function (pe::PositionalEncoding)(x::Array)
    # NOTE(review): `pe.pe[:, size(x)[2]]` selects a single column of the
    # max_len x d_model table (a vector of length max_len). A slice covering the
    # first `size(x, 2)` positions was presumably intended; confirm the input
    # layout before changing it.
    encoded = x .+ pe.pe[:, size(x)[2]]
    # Dropout is the layer stored on `pe`, not a free function.
    return pe.dropout(encoded)
end

In [12]:
"""
    make_model(src_vocab, tgt_vocab, N, d_model, d_ff, h, dropout; max_len=5000)

Helper: construct a model from hyperparameters.

# Arguments
- `src_vocab`, `tgt_vocab`: source and target vocabulary sizes.
- `N`: number of encoder layers and of decoder layers.
- `d_model`: model width.
- `d_ff`: hidden width of the feed-forward blocks.
- `h`: number of attention heads.
- `dropout`: dropout probability used throughout.

# Keywords
- `max_len`: number of positions in the positional-encoding table.

# Returns
An `EncoderDecoder` whose sub-modules each hold their own deep copies of the
attention, feed-forward and positional-encoding prototypes.
"""
function make_model(src_vocab::Int, tgt_vocab::Int, N::Int,
                    d_model::Int, d_ff::Int, h::Int, dropout::Float64;
                    max_len::Int=5000)
    # Prototypes; every consumer below receives an independent deep copy.
    attn = MultiHeadedAttention(d_model, h, dropout)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout, max_len)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), dropout), N)
    decoder = Decoder(
        DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), dropout),
        N,
    )
    src_embed = SourceEmbedding(Chain(Embeddings(d_model, src_vocab), deepcopy(position)))
    tgt_embed = TargetEmbedding(Chain(Embeddings(d_model, tgt_vocab), deepcopy(position)))
    generator = Generator(d_model, tgt_vocab)

    # No explicit parameter re-initialisation happens here; each layer keeps
    # whatever initialisation its constructor applied.
    return EncoderDecoder(encoder, decoder, src_embed, tgt_embed, generator)
end

make_model (generic function with 1 method)

In [13]:
# Smoke test: vocabularies of 10, N = 2 layers, d_model = 512, d_ff = 2048,
# h = 8 heads, dropout = 0.1.
tmp_model = make_model(10, 10, 2, 512, 2048, 8, 0.1)

EncoderDecoder(Encoder(EncoderLayer[EncoderLayer(512, MultiHeadedAttention(512, 8, 64, Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}[Dense(512, 512), Dense(512, 512), Dense(512, 512), Dense(512, 512)], Dropout(0.1), nothing), PositionwiseFeedForward(Dense(512, 2048), Dense(2048, 512), Dropout(0.1)), Sublayer[Sublayer(512, LayerNorm(512), Dropout(0.1)), Sublayer(512, LayerNorm(512), Dropout(0.1))]), EncoderLayer(512, MultiHeadedAttention(512, 8, 64, Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}[Dense(512, 512), Dense(512, 512), Dense(512, 512), Dense(512, 512)], Dropout(0.1), nothing), PositionwiseFeedForward(Dense(512, 2048), Dense(2048, 512), Dropout(0.1)), Sublayer[Sublayer(512, LayerNorm(512), Dropout(0.1)), Sublayer(512, LayerNorm(512), Dropout(0.1))])], LayerNorm(512)), Decoder(DecoderLayer[DecoderLayer(512, MultiHeadedAttention(512, 8, 64, Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}[Dense(512, 512), Dense(512, 512), Dense(512, 512), Dense(512