In [34]:
using Flux,LinearAlgebra,Distances,StatsBase

┌ Info: Precompiling StatsBase [2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91]
└ @ Base loading.jl:1278


In [None]:
"""
    clone(m, N)

Return a vector of `N` entries, each one a separate `deepcopy` of `m`.
"""
function clone(m::Any, N::Int)
    copies = map(_ -> deepcopy(m), 1:N)
    return copies
end

"""
    MultiHeadedAttention(d_model, h, dropout)

State for multi-head attention.

Fields:
- `d_model`: model width passed to the constructor.
- `h`: number of heads.
- `d_k`: `Int(d_model / h)`; the conversion throws if `h` does not divide `d_model`.
- `linears`: four cloned `Dense(d_model, d_model)` layers.
- `dropout`: a `Dropout` layer built from the given probability.
- `attn`: starts as `nothing`; overwritten by the forward call.
"""
mutable struct MultiHeadedAttention
    d_model::Int
    h::Int
    d_k::Int
    linears::Array
    dropout::Dropout
    attn::Union{Array,Nothing}

    function MultiHeadedAttention(d_model::Int, h::Int, dropout::Float64)
        head_dim = Int(d_model / h)
        projections = clone(Dense(d_model, d_model), 4)
        return new(d_model, h, head_dim, projections, Dropout(dropout), nothing)
    end
end
# Forward pass of multi-head attention.
#
# Arguments
#   query, key, value : input arrays; each is passed through one of the first
#                       three layers in `mh.linears`.
#   mask              : optional mask (or `nothing`); a singleton dimension is
#                       inserted at position 2 before it is handed to `attention`.
#
# Returns the output of the last layer in `mh.linears` applied to the re-joined
# heads. As a side effect the attention weights returned by `attention` are
# stored in `mh.attn`.
#
# NOTE(review): the reshape/permutedims pattern assumes inputs are laid out as
# (batch, sequence, d_model) — confirm this matches what the `Dense` layers
# expect for their feature dimension.
function (mh::MultiHeadedAttention)(query::Array, key::Array, value::Array, mask::Union{AbstractArray,Nothing})
    # Same mask applied to all h heads.
    if mask !== nothing
        mask = unsqueeze(mask, 2)
    end

    nbatches = size(query)[1]

    # 1) Do all the linear projections in batch from d_model => h x d_k
    query, key, value = [
        permutedims(reshape(l(x), nbatches, :, mh.h, mh.d_k), (1, 3, 2, 4))
        for (l, x) in zip(mh.linears, (query, key, value))
    ]

    # 2) Apply attention on all the projected vectors in batch.
    x, mh.attn = attention(query, key, value, mask, mh.dropout)

    # 3) Move the head dimension back next to d_k, merge the two, and apply
    #    the final linear layer.
    x = reshape(permutedims(x, (1, 3, 2, 4)), nbatches, :, mh.h * mh.d_k)

    return mh.linears[end](x)
end

"""
    PositionwiseFeedForward(d_model, d_ff, dropout)

Two-layer feed-forward block: `w_1` is `Dense(d_model, d_ff)`, `w_2` is
`Dense(d_ff, d_model)`, and `dropout` is a `Dropout` layer with the given
probability.
"""
mutable struct PositionwiseFeedForward
    w_1::Dense
    w_2::Dense
    dropout::Dropout

    function PositionwiseFeedForward(d_model::Int, d_ff::Int, dropout::Float64)
        expand = Dense(d_model, d_ff)
        contract = Dense(d_ff, d_model)
        return new(expand, contract, Dropout(dropout))
    end
end

# Apply the feed-forward block: w_1, element-wise relu, dropout, then w_2.
#
# `relu` is a scalar activation, so it is broadcast over the array produced by
# `w_1` instead of being called on the array directly.
function (pl::PositionwiseFeedForward)(x::Array)
    hidden = relu.(pl.w_1(x))
    return pl.w_2(pl.dropout(hidden))
end

"""
    Sublayer(size, dropout)

Holds the pieces of a residual connection: the feature `size`, a
`LayerNorm(size)` and a `Dropout` layer with the given probability.
"""
mutable struct Sublayer
    size::Int
    norm::LayerNorm
    dropout::Dropout

    function Sublayer(size::Int, dropout::Float64)
        normalizer = LayerNorm(size)
        return new(size, normalizer, Dropout(dropout))
    end
end

# Residual connection around `f`: returns x .+ dropout(f(norm(x))).
#
# Arguments
#   x : input matrix.
#   f : any callable taking the normalised input. It is deliberately not
#       restricted to `Function`, so callable structs (for example a
#       feed-forward layer object) are accepted as well as closures.
#
# The layer's own `norm` and `dropout` fields are used; the bare names `norm`
# and `dropout` would not refer to this layer's state.
function (sb::Sublayer)(x::Matrix, f)
    normalized = sb.norm(x)
    return x .+ sb.dropout(f(normalized))
end

"""
    EncoderLayer(size, self_attn, feed_forward, dropout)

One encoder layer: a self-attention module, a feed-forward module, and two
cloned `Sublayer(size, dropout)` wrappers stored in `sublayer`.
"""
mutable struct EncoderLayer
    size::Int
    self_attn::MultiHeadedAttention
    feed_forward::PositionwiseFeedForward
    sublayer::Array

    function EncoderLayer(size::Int,
                          self_attn::MultiHeadedAttention,
                          feed_forward::PositionwiseFeedForward,
                          dropout::Float64)
        wrappers = clone(Sublayer(size, dropout), 2)
        return new(size, self_attn, feed_forward, wrappers)
    end
end

# Run one encoder layer: self-attention through the first sublayer wrapper,
# then the feed-forward module through the second.
#
# The sublayer wrapper calls its function argument with a single value (the
# normalised input), so each closure takes exactly one argument and captures
# `src_mask` from the enclosing scope.
function (el::EncoderLayer)(x::Matrix, src_mask::Matrix)
    x = el.sublayer[1](x, h -> el.self_attn(h, h, h, src_mask))
    # Wrapped in a closure so the argument is a `Function`, which the sublayer
    # method signature requires.
    return el.sublayer[2](x, h -> el.feed_forward(h))
end

"""
    Encoder(layer, N)

Stack of `N` cloned copies of `layer` plus a final `LayerNorm(layer.size)`.
"""
mutable struct Encoder
    layers::Array
    norm::LayerNorm

    function Encoder(layer::EncoderLayer, N::Int)
        stack = clone(layer, N)
        return new(stack, LayerNorm(layer.size))
    end
end

# Pass `x` (with `src_mask`) through every layer in turn, then apply the
# encoder's final normalisation.
function (en::Encoder)(x::Matrix, src_mask::Matrix)
    for layer in en.layers
        x = layer(x, src_mask)
    end
    # The receiver is `en`; there is no `self` binding in this method.
    return en.norm(x)
end

"""
    DecoderLayer(size, self_attn, src_attn, feed_forward, dropout)

One decoder layer: self-attention, source attention, a feed-forward module,
and three cloned `Sublayer(size, dropout)` wrappers stored in `sublayer`.
"""
mutable struct DecoderLayer
    size::Int
    self_attn::MultiHeadedAttention
    src_attn::MultiHeadedAttention
    feed_forward::PositionwiseFeedForward
    sublayer::Array

    function DecoderLayer(size::Int,
                          self_attn::MultiHeadedAttention,
                          src_attn::MultiHeadedAttention,
                          feed_forward::PositionwiseFeedForward,
                          dropout::Float64)
        wrappers = clone(Sublayer(size, dropout), 3)
        return new(size, self_attn, src_attn, feed_forward, wrappers)
    end
end

# Run one decoder layer: masked self-attention, attention over `memory`, then
# the feed-forward module, each through its own sublayer wrapper.
#
# The sublayer wrapper calls its function argument with a single value, so
# every closure takes one argument; `memory`, `src_mask` and `tgt_mask` are
# captured from the enclosing scope.
function (dl::DecoderLayer)(x::Matrix, memory::Matrix, src_mask::Matrix, tgt_mask::Matrix)
    x = dl.sublayer[1](x, h -> dl.self_attn(h, h, h, tgt_mask))
    x = dl.sublayer[2](x, h -> dl.src_attn(h, memory, memory, src_mask))
    # Uses this layer's own feed-forward module (`dl`), wrapped in a closure so
    # the argument is a `Function`.
    return dl.sublayer[3](x, h -> dl.feed_forward(h))
end


"""
    Decoder(layer, N)

Stack of `N` cloned copies of `layer` plus a final `LayerNorm(layer.size)`.
"""
mutable struct Decoder
    layers::Array
    norm::LayerNorm

    function Decoder(layer::DecoderLayer, N::Int)
        stack = clone(layer, N)
        return new(stack, LayerNorm(layer.size))
    end
end

# Pass `x` through every decoder layer (each also receives `memory` and both
# masks), then apply the decoder's final normalisation.
function (dc::Decoder)(x::Matrix, memory::Matrix, src_mask::Matrix, tgt_mask::Matrix)
    for layer in dc.layers
        x = layer(x, memory, src_mask, tgt_mask)
    end
    # The receiver is `dc`; there is no `self` binding in this method.
    return dc.norm(x)
end


"""
    subsequent_mask(size)

Mask out subsequent positions.

Builds a `size`×`size` lower-triangular matrix of ones, compares it
element-wise with `1.0`, and passes the resulting array through
`unsqueeze(…, 1)`.
"""
function subsequent_mask(size::Int)
    lower = LowerTriangular(ones(size, size))
    allowed = Array(lower .== 1.0)
    return unsqueeze(allowed, 1)
end

"""
    fill_mask!(a, mask, esp=-1e9)

Overwrite, in place, every element of `a` whose corresponding `mask` entry is
`false` with `esp`, and return `a`.

`mask` is used for logical indexing, so it must have the same shape as `a`.
The arguments are typed as `AbstractArray` so that `BitArray` masks (the
result of broadcast comparisons such as `x .!= pad`) are accepted.
"""
function fill_mask!(a::AbstractArray, mask::AbstractArray, esp=-1e9)
    a[.~mask] .= esp
    return a
end

"""
    swap_last_2_dimesions!(a)

Return a new array equal to `a` with its last two dimensions exchanged
(via `permutedims`). Despite the `!` in the name, `a` itself is not modified.
"""
function swap_last_2_dimesions!(a::Array)
    order = collect(1:ndims(a))
    # Exchange the final two entries of the permutation.
    order[end - 1], order[end] = order[end], order[end - 1]
    return permutedims(a, order)
end

"""
    SourceEmbedding(m)

Wrapper around the `Chain` used to embed source inputs.
"""
mutable struct SourceEmbedding
    m::Chain
end

# Apply the wrapped chain to `x`. The chain is read from the receiver's field;
# a bare `m` would not refer to it.
function (se::SourceEmbedding)(x::Array)
    return se.m(x)
end

"""
    TargetEmbedding(m)

Wrapper around the `Chain` used to embed target inputs.
"""
mutable struct TargetEmbedding
    m::Chain
end

# Apply the wrapped chain to `x`. The chain is read from the receiver's field;
# a bare `m` would not refer to it.
function (se::TargetEmbedding)(x::Array)
    return se.m(x)
end

"""
    Generator(d_model, vocab)

Output head: a `Chain` of `Dense(d_model, vocab)` followed by `logsoftmax`.
The projection width is `vocab`, so the constructor's `vocab` argument
determines the size of the output distribution.
"""
struct Generator
    m::Chain

    function Generator(d_model::Int, vocab::Int)
        return new(Chain(Dense(d_model, vocab), logsoftmax))
    end
end

# Apply the generator's chain to `x`. Calling `g(x)` here would re-enter this
# same method and never terminate, so the wrapped chain `g.m` is invoked.
function (g::Generator)(x::Array)
    return g.m(x)
end

"""
    EncoderDecoder(encoder, decoder, src_embed, tgt_embed, generator)

Container bundling the encoder, the decoder, the source and target embedding
wrappers, and the output generator.
"""
mutable struct EncoderDecoder
    encoder::Encoder
    decoder::Decoder
    src_embed::SourceEmbedding
    tgt_embed::TargetEmbedding
    generator::Generator

    function EncoderDecoder(encoder::Encoder,
                            decoder::Decoder,
                            src_embed::SourceEmbedding,
                            tgt_embed::TargetEmbedding,
                            generator::Generator)
        return new(encoder, decoder, src_embed, tgt_embed, generator)
    end
end

"""
    encode(ed, src, src_mask)

Embed `src` with `ed.src_embed` and run the result through `ed.encoder`
together with `src_mask`.

The embedding wrapper has only a one-argument call method, so the mask is
passed to the encoder only.
"""
function encode(ed::EncoderDecoder, src::Matrix, src_mask::Matrix)
    embedded_src = ed.src_embed(src)
    return ed.encoder(embedded_src, src_mask)
end

"""
    decode(ed, memory, src_mask, tgt, tgt_mask)

Embed `tgt` with `ed.tgt_embed` and run `ed.decoder` on it together with
`memory` and both masks.
"""
function decode(ed::EncoderDecoder, memory::Matrix, src_mask::Matrix, tgt::Matrix, tgt_mask::Matrix)
    embedded_tgt = ed.tgt_embed(tgt)
    return ed.decoder(embedded_tgt, memory, src_mask, tgt_mask)
end


# Full forward pass: encode `src` under `src_mask`, then decode `tgt` against
# the encoder output using both masks.
function (ed::EncoderDecoder)(src::Matrix, tgt::Matrix, src_mask::Matrix, tgt_mask::Matrix)
    memory = encode(ed, src, src_mask)
    return decode(ed, memory, src_mask, tgt, tgt_mask)
end




In [None]:
"""
    attention(query, key, value, mask, dropout)

Compute 'Scaled Dot Product Attention'.

# Arguments
- `query`, `key`, `value`: input arrays; the size of the last dimension of
  `query` is used as `d_k` for scaling.
- `mask`: `nothing`, or a Boolean array; positions where it is `false` are
  overwritten with a large negative score by `fill_mask!`.
- `dropout`: `nothing`, or any callable (for example a `Dropout` layer) applied
  to the attention weights.

# Returns
A tuple `(p_attn * value, p_attn)`.

NOTE(review): `*` is only defined for vectors and matrices. If callers pass
arrays with more than two dimensions, a batched product will be needed here —
confirm the intended input rank.
NOTE(review): `fill_mask!` writes into `scores` in place and indexes it with
the mask, so the mask must match the shape of `scores`; in-place writes may
also be a problem under automatic differentiation — verify.
"""
function attention(query, key, value, mask::Union{AbstractArray,Nothing}, dropout)
    d_k = size(query)[end]
    scores = (query * swap_last_2_dimesions!(key)) ./ sqrt(d_k)

    if mask !== nothing
        scores = fill_mask!(scores, mask)
    end

    # Normalise over the last dimension of the score array.
    p_attn = softmax(scores, dims = ndims(scores))

    if dropout !== nothing
        p_attn = dropout(p_attn)
    end
    return p_attn * value, p_attn
end

In [None]:
"""
    unsqueeze(xs, dim)

Reshape `xs` so that a new dimension of length 1 sits at position `dim`;
the existing dimensions from `dim` onward move one place to the right.
"""
function unsqueeze(xs, dim)
    dims = size(xs)
    leading = dims[1:dim-1]
    trailing = dims[dim:end]
    return reshape(xs, (leading..., 1, trailing...))
end

In [None]:
"""
    EmbeddingLayer(mf, vs)

Embedding table `W` of size `mf`×`vs`, initialised with `Flux.glorot_normal`.
"""
mutable struct EmbeddingLayer
    W
    EmbeddingLayer(mf, vs) = new(Flux.glorot_normal(mf, vs))
end

# Look up embeddings by multiplying `W` with a one-hot encoding of `x`.
#
# `x` is flattened with `vec`, and the label set is derived from the table
# itself (`0:size(W, 2) - 1`), so the method does not depend on any global
# variables. Labels are therefore expected to be 0-based integers smaller than
# the number of columns of `W`.
function (m::EmbeddingLayer)(x)
    labels = 0:(size(m.W, 2) - 1)
    return m.W * Flux.onehotbatch(vec(x), labels)
end

In [None]:
"""
    Embeddings(d_model, vocab)

Pairs an `EmbeddingLayer(d_model, vocab)` lookup table (`lut`) with the model
width `d_model`.
"""
mutable struct Embeddings
    lut::EmbeddingLayer
    d_model::Int

    function Embeddings(d_model::Int, vocab::Int)
        table = EmbeddingLayer(d_model, vocab)
        return new(table, d_model)
    end
end

# Embed `x` with the lookup table and scale the result by sqrt(d_model).
#
# The lookup goes through `emb.lut` (calling `emb(x)` would re-enter this
# method) and the width is read from `emb.d_model`.
function (emb::Embeddings)(x::Array)
    return emb.lut(x) .* sqrt(emb.d_model)
end


In [None]:
"""
    init_pe(max_len, d_model)

Build a `max_len`×`d_model` table of sinusoidal values.

Row `p` corresponds to position `p - 1`. Odd-numbered columns hold `sin` and
even-numbered columns hold `cos` of `position * exp(-k * log(10000) / d_model)`
for `k = 0, 2, 4, …`. `Int(d_model / 2)` is used for the column count, so an
odd `d_model` raises an error.
"""
function init_pe(max_len::Int, d_model::Int)
    table = zeros(max_len, d_model)
    positions = unsqueeze(Array(range(0, length=max_len, step=1)), 2)
    half = Int(d_model / 2)
    freq_scale = -(log(10000.0) / d_model)
    inv_freq = exp.(Array(range(0, length=half, step=2)) .* freq_scale)
    angles = positions .* transpose(inv_freq)

    even_cols = range(2, length=half, step=2)
    odd_cols = range(1, length=half, step=2)
    table[:, even_cols] = cos.(angles)
    table[:, odd_cols] = sin.(angles)
    return table
end
"""
    PositionalEncoding(d_model, p_dropout, max_len)

Stores the model width, a `Dropout(p_dropout)` layer, the maximum length, and
the table `pe` produced by `init_pe(max_len, d_model)`.
"""
struct PositionalEncoding
    d_model::Int
    dropout::Dropout
    max_len::Int
    pe::Array

    function PositionalEncoding(d_model::Int, p_dropout::Float64, max_len::Int)
        drop = Dropout(p_dropout)
        table = init_pe(max_len, d_model)
        return new(d_model, drop, max_len, table)
    end
end

# Add values from the positional table to `x` and apply this layer's dropout.
#
# The layer's own `dropout` field is used; a bare `dropout` would not refer to
# this layer's state.
#
# NOTE(review): `pe.pe[:, size(x)[2]]` selects one single column of the table
# (the column whose index equals the second dimension of `x`). If the intent
# is to add one table row per position, this indexing needs revisiting —
# confirm the expected layout of `x`.
function (pe::PositionalEncoding)(x::Array)
    encoded = x .+ pe.pe[:, size(x)[2]]
    return pe.dropout(encoded)
end

In [None]:
"""
    make_model(src_vocab, tgt_vocab, N, d_model, d_ff, h, dropout)

Helper: Construct a model from hyperparameters.

One attention module, one feed-forward module and one positional encoding
(with `max_len` 5000) are created as templates; every place that uses them
receives its own `deepcopy`. The encoder and the decoder each stack `N`
layers. No separate parameter re-initialisation is performed here.
"""
function make_model(src_vocab::Int, tgt_vocab::Int, N::Int,
                    d_model::Int, d_ff::Int, h::Int, dropout::Float64)
    attn = MultiHeadedAttention(d_model, h, dropout)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout, 5000)

    encoder = Encoder(
        EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), dropout), N)
    decoder = Decoder(
        DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), dropout), N)
    src_embed = SourceEmbedding(
        Chain(Embeddings(d_model, src_vocab), deepcopy(position)))
    tgt_embed = TargetEmbedding(
        Chain(Embeddings(d_model, tgt_vocab), deepcopy(position)))
    generator = Generator(d_model, tgt_vocab)

    return EncoderDecoder(encoder, decoder, src_embed, tgt_embed, generator)
end

In [None]:
# Build a throw-away model to check that construction works:
# src_vocab = 10, tgt_vocab = 10, N = 2, d_model = 512, d_ff = 2048, h = 8, dropout = 0.1.
tmp_model = make_model(10, 10, 2,512,2048,8,0.1)

In [None]:
"""
    make_std_mask(tgt, pad)

Create a mask to hide padding and future words.

The padding mask `tgt .!= pad` gets a singleton dimension inserted just before
its last dimension and is then combined (element-wise AND, with broadcasting)
with `subsequent_mask(size(tgt)[end])`.

The insertion position is derived from the number of dimensions of the mask,
not from the length of one of its axes, so it is valid for any sequence
length.
"""
function make_std_mask(tgt::Array, pad::Int)
    pad_mask = tgt .!= pad
    pad_mask = unsqueeze(pad_mask, ndims(pad_mask))
    return pad_mask .& subsequent_mask(size(tgt)[end])
end

In [None]:
"""
    Batch(src, trg_i, pad)

Holds one batch together with its masks.

# Fields
- `src`: the source array as given.
- `src_mask`: `src .!= pad` with a singleton dimension inserted just before
  the last dimension.
- `trg`: `trg_i` without its last column, or `nothing` when no target is given.
- `trg_y`: `trg_i` without its first column, or `nothing`.
- `trg_mask`: `make_std_mask(trg, pad)`, or `nothing`.
- `ntokens`: number of entries of `trg_y` different from `pad` (0 without a
  target).
"""
struct Batch
    src::Array
    src_mask::Array
    trg::Union{Array,Nothing}
    trg_y::Union{Array,Nothing}
    trg_mask::Union{Array,Nothing}
    ntokens::Int

    function Batch(src::Array, trg_i::Union{Array,Nothing}, pad::Int)
        src_mask = src .!= pad
        # Insert the new axis relative to the number of dimensions, not to the
        # length of an axis.
        src_mask = unsqueeze(src_mask, ndims(src_mask))

        # Defaults used when no target is supplied, so every field is defined.
        trg = nothing
        trg_y = nothing
        trg_mask = nothing
        ntokens = 0

        if trg_i !== nothing
            trg = trg_i[:, 1:end-1]
            trg_y = trg_i[:, 2:end]
            trg_mask = make_std_mask(trg, pad)
            ntokens = sum(trg_y .!= pad)
        end
        return new(src, src_mask, trg, trg_y, trg_mask, ntokens)
    end
end

# Training Loop

In [None]:
"""
    run_epoch!(loss, data, model, opt)

Run one pass over `data`, updating the parameters of `model` with `opt`.

# Arguments
- `loss`: loss object forwarded to `loss_function(model, batch, loss)`.
- `data`: iterable of batches.
- `model`: the model whose parameters are collected with `Flux.params`.
- `opt`: optimiser passed to `Flux.Optimise.update!`.

The positional order matches the call `run_epoch!(slc, data, model, model_opt)`
used by the training script in this file.

NOTE(review): `Flux.params` only collects parameters from types it knows how
to traverse — confirm that the model's custom structs expose their weights.
"""
function run_epoch!(loss, data, model, opt)
    ps = Flux.params(model)

    for batch in data
        # `back` computes the product of the gradient so far with its argument.
        train_loss, back = Flux.pullback(() -> loss_function(model, batch, loss), ps)
        # Hook point: code needing `train_loss` (e.g. logging) goes here.
        # Apply `back` to a one of the loss's type to obtain the gradients.
        gs = back(one(train_loss))
        # Hook point: code needing the gradients (e.g. histogram logging).
        Flux.Optimise.update!(opt, ps, gs)
        # Hook point: validation / early stopping.
    end
    return nothing
end

In [37]:
"""
    LabelSmoothing(size, padding_idx, smoothing)

State for a label-smoothing loss.

# Fields
- `size`: expected size of the second dimension of the input.
- `criterion`: the divergence function, set to `StatsBase.kldivergence`.
- `padding_idx`: index treated as padding.
- `confidence`: `1 - smoothing`.
- `smoothing`: the smoothing amount as given.
- `true_dist`: last target distribution that was built; starts as `nothing`
  so that reading it before the first call does not hit an undefined field.
"""
mutable struct LabelSmoothing
    size::Int
    criterion::Function
    padding_idx::Int
    confidence::Float64
    smoothing::Float64
    true_dist::Union{Array,Nothing}

    function LabelSmoothing(size::Int, padding_idx::Int, smoothing::Float64)
        return new(size, StatsBase.kldivergence, padding_idx,
                   1 - smoothing, smoothing, nothing)
    end
end

In [39]:
"""
    @assert ex

Evaluate `ex`; return `nothing` when it is true, otherwise throw an
`AssertionError` whose message is the source text of `ex`.

The expression is escaped so that it is evaluated in the caller's scope.
Without the escape, macro hygiene would resolve the names inside `ex` as
globals of the defining module, and local variables or function arguments
could not be asserted on.
"""
macro assert(ex)
    return :( $(esc(ex)) ? nothing : throw(AssertionError($(string(ex)))) )
end

@assert (macro with 1 method)

In [40]:
"""
    scatter_dim!(x, target, value)

For every row index `i` from 1 to the first dimension of `target`, write
`value` into `x[i, target[i]]`. Modifies `x` in place and returns it.
"""
function scatter_dim!(x::Array, target::Array, value::Float64)
    nrows = size(target)[1]
    foreach(1:nrows) do row
        x[row, target[row]] = value
    end
    return x
end

scatter_dim! (generic function with 1 method)

In [50]:
# Build a smoothed target distribution for `target` and return
# `ls.criterion(x, true_dist)`.
#
# Steps:
#   1. every entry starts at smoothing / (size - 2);
#   2. the entry selected by each target label is set to `confidence`;
#   3. the padding column is zeroed;
#   4. rows whose target equals `padding_idx` are zeroed entirely.
# The distribution is also stored in `ls.true_dist`.
#
# Throws `AssertionError` when the second dimension of `x` differs from
# `ls.size`. The check is written out explicitly so that it always evaluates
# the local `x` and `ls`, independent of how an `@assert` macro in scope treats
# its argument.
#
# NOTE(review): `padding_idx` is used directly as a column index, so it must be
# a valid 1-based index — confirm against the label encoding used by callers.
function (ls::LabelSmoothing)(x::Array, target::Array)
    size(x)[2] == ls.size || throw(AssertionError("size(x)[2] == ls.size"))

    true_dist = fill(ls.smoothing / (ls.size - 2), size(x))
    true_dist = scatter_dim!(true_dist, target, ls.confidence)
    true_dist[:, ls.padding_idx] .= 0

    pad_rows = findall(target .== ls.padding_idx)
    if !isempty(pad_rows)
        true_dist[pad_rows, :] .= 0
    end

    ls.true_dist = true_dist
    return ls.criterion(x, true_dist)
end

In [None]:
"""
    LossCompute(generator, criterion)

Pairs the output generator with the label-smoothing criterion.
"""
mutable struct LossCompute
    generator::Generator
    criterion::LabelSmoothing

    function LossCompute(generator::Generator, criterion::LabelSmoothing)
        return new(generator, criterion)
    end
end

# Run `x` through the generator, score it against `y` with the criterion, and
# divide the result by `norm`.
function (slc::LossCompute)(x::Array, y::Array, norm::Int)
    generated = slc.generator(x)
    return slc.criterion(generated, y) / norm
end

In [None]:
"""
    loss_function(model, batch, slc)

Run `model` on the batch's source, target and masks, then return
`slc(out, batch.trg_y, batch.ntokens)`.
"""
function loss_function(model, batch, slc)
    out = model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
    return slc(out, batch.trg_y, batch.ntokens)
end

In [None]:
# Train the simple copy task.
# V is used both as the vocabulary size for the model and as the `size` of the
# label-smoothing criterion.
V = 11
# Label smoothing with padding index 0 and smoothing 0.0.
criterion = LabelSmoothing(V, 0, 0.0)
# N = 2 layers, d_model = 512, d_ff = 2048, h = 8 heads, dropout = 0.1.
model =  make_model(V, V, 2,512,2048,8,0.1) 
model_opt = Optimiser(ExpDecay(), ADAM())
no_of_epoch=10
# The loss object shares the model's generator.
slc =  LossCompute(model.generator,criterion)


# NOTE(review): `data_gen` is not defined in the visible part of this file —
# confirm where it comes from and what its three arguments mean.
data = data_gen(V, 30, 20)

# One call to run_epoch! per epoch, always over the same `data`.
for epoch in 1:no_of_epoch
    run_epoch!(slc,data,model,model_opt)
end

In [None]:
# NOTE(review): this cell is Python/PyTorch code that has not been ported to
# Julia; it will not run in a Julia kernel.
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Greedily decode a sequence of up to ``max_len`` symbols.

    The source is encoded once. Starting from a 1x1 tensor holding
    ``start_symbol``, each of the ``max_len - 1`` iterations decodes the
    current output, feeds the last position to ``model.generator``, takes the
    index of the maximum along dim 1, and appends it to the output.

    Returns the tensor ``ys`` of chosen symbols, including the start symbol.
    """
    # Encode the source once; the result is reused at every step.
    memory = model.encode(src, src_mask)
    # 1x1 tensor filled with the start symbol, cast to the type of src.data.
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len-1):
        # Decode with a subsequent-position mask sized to the current output.
        out = model.decode(memory, src_mask, 
                           Variable(ys), 
                           Variable(subsequent_mask(ys.size(1))
                                    .type_as(src.data)))
        # Only the last position is passed to the generator.
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.data[0]
        # Append the chosen symbol as a new column.
        ys = torch.cat([ys, 
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys

# Python/PyTorch demo (not ported to Julia): switch the model to eval mode and
# greedily decode the sequence 1..10 with an all-ones source mask.
model.eval()
src = Variable(torch.LongTensor([[1,2,3,4,5,6,7,8,9,10]]) )
src_mask = Variable(torch.ones(1, 1, 10) )
print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1))

In [None]:
global max_src_in_batch, max_tgt_in_batch
def batch_size_fn(new, count, sofar):
    """Keep augmenting batch and calculate total number of tokens + padding.

    The running maxima are kept in module-level globals and are reset to 0
    whenever ``count`` is 1. The target length is counted as
    ``len(new.trg) + 2``. Returns the larger of ``count * max source length``
    and ``count * max target length``; ``sofar`` is not used.
    """
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        # First example of a new batch: restart both running maxima.
        max_src_in_batch, max_tgt_in_batch = 0, 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
    padded_sizes = (count * max_src_in_batch, count * max_tgt_in_batch)
    return max(padded_sizes)