In [1]:
using Knet: Knet, minibatch, param, param0, softmax, nll, RNN
using Pickle
using NPZ
using Statistics: mean
using DataStructures: OrderedDict

In [2]:
# Load data
dir = "..\\dataset\\representations\\uncond\\cp\\ailab17k_from-scratch_cp"
t2i, i2t = Pickle.load(open("$(dir)\\dictionary.pkl"))
train = NPZ.npzread("$(dir)\\train_data_linear.npz")
test = NPZ.npzread("$(dir)\\test_data_linear.npz")

toktids = OrderedDict("tempo"=>1,"chord"=>2,"bar-beat"=>3,"type"=>4,"pitch"=>5, "duration"=>6, "velocity"=>7)
n_tokens = [length(t2i[k]) for (k, v) in toktids]

t2i # For reference

Dict{Any, Any} with 7 entries:
  "tempo"    => Dict{Any, Any}("Tempo_131"=>12, "Tempo_92"=>53, "Tempo_176"=>24…
  "duration" => Dict{Any, Any}("Note_Duration_720"=>15, "Note_Duration_840"=>16…
  "pitch"    => Dict{Any, Any}("Note_Pitch_34"=>21, "Note_Pitch_86"=>73, "Note_…
  "chord"    => Dict{Any, Any}("D_/o7"=>69, "F_M7"=>105, "C_m"=>51, "C_o"=>53, …
  "bar-beat" => Dict{Any, Any}("Bar"=>1, "Beat_1"=>3, "Beat_4"=>12, "Beat_11"=>…
  "velocity" => Dict{Any, Any}("Note_Velocity_54"=>8, "Note_Velocity_82"=>22, "…
  "type"     => Dict{Any, Any}("EOS"=>0, "Metrical"=>1, "Note"=>2)

In [3]:
# Model settings
BATCH_SIZE = 25
EMBED_SIZES = [256, 256, 64, 32, 512, 128, 128]
D_MODEL = 512
N_HEAD = 8

8

In [4]:
# Minibatching
train_x = trunc.(Int, permutedims(
          cat(train["x"], reshape(train["mask"], (size(train["x"],1),size(train["x"],2),1)), dims=3),
          [2, 3, 1]).+1);                                           @show size(train_x) # T, K+1, B
train_y = trunc.(Int, permutedims(train["y"], [2, 3, 1]).+1);       @show size(train_y) # T, K, B
test_x = trunc.(Int, permutedims(
         cat(test["x"], reshape(test["mask"], (size(test["x"],1),size(test["x"],2),1)), dims=3),
         [2, 3, 1]).+1);                                            @show size(test_x) # T, K+1, B
test_y = trunc.(Int, permutedims(test["y"], [2, 3, 1]).+1);         @show size(test_y) # T, K, B

train_loader = minibatch(train_x, train_y, BATCH_SIZE; shuffle=true)
test_loader = minibatch(test_x, test_y, BATCH_SIZE; shuffle=true)

length.((train_loader, test_loader))

size(train_x) = 

(3584, 8, 1625)
size(train_y) = 

(3584, 7, 1625)
size(test_x) = (3584, 8, 50)
size(test_y) = (3584, 7, 50)


(65, 2)

# Model

In [5]:
# Simple useful layers

struct Linear; W; b; end
Linear(input::Int, output::Int) = Linear(param(output, input), param0(output))
(l::Linear)(x) = l.W*x .+ l.b

struct Embedding; W; end
Embedding(n_tokens::Int, embed::Int) = Embedding(param(embed, n_tokens))
(e::Embedding)(x) = e.W[:, x]

In [6]:
# TODO: Linear Transformer backbone
# ϕ(x) = elu(x) + 1 
# struct Transformer; W_Q; W_K; W_V; ff; end
# Transformer(n_layers, n_heads, q_dim, v_dim, ff_dim; activation=ϕ, dropout=0.1) =
#     Transformer(param(q_dim, ???), param(k_dim, ???), param(v_dim, ???), ff_dim)

In [7]:
# Sampling function for predicting tokens
function sampling(x; dims=1) # TODO : Temperature
    # size(x) = (N_tokens[type], B, T)
    x = softmax(x, dims=dims)
    first.(Tuple.(argmax(x, dims=dims))) # TODO: Weighted sampling
end

sampling (generic function with 1 method)

In [8]:
struct CPTransformer; embeds; lin_in; lin_transformer; projs; blend_type; end

CPTransformer(n_tokens::Vector{Int}, embed_sizes::Vector{Int}, d_model::Int, d_inner::Int;
    blend_dim=32) =
    CPTransformer([Embedding(n, e) for (n, e) in zip(n_tokens, embed_sizes)],
            Linear(sum(embed_sizes), d_model),
            RNN(d_model, d_model), # Placeholder until Transformer implementation
            [Linear(d_model, n) for n in n_tokens],
            Linear(d_model + blend_dim, d_model))

# y    => y != nothing ? [training mode] : [interference mode]
# gen  => gen ? return ŷ : return ŷ_P
function (model::CPTransformer)(x; y=nothing, gen=false)
    x, mask = x[:, 1:end-1, :], x[:, end, :];                    @show size(x) # (T, N_tkn, B)
    
    x = vcat([embed(x[:, i, :]) for (embed, i) in
            zip(model.embeds, 1:length(model.embeds))]...);      @show size(x) # (X_emb, T, B)
    
    x = cat([model.lin_in(x[:,i,:]) for i in
            1:size(x, 2)]..., dims=3);                           @show size(x) # (X_in, B, T)
    
    # x = Positional_Encoding(x) <-- TODO
    
    h = model.lin_transformer(x);                                @show size(h) # (D_m, B, T)
    
    ŷ_type_P = (cat([model.projs[toktids["type"]](h[:,:,i])
            for i in 1:size(h, 3)]..., dims=3));                 @show size(ŷ_type_P) # (N_tvoc, B, T)
    
    ŷ_type = y!=nothing ? y[:, toktids["type"], :] : 
            reshape(sampling(ŷ_type_P), (size(ŷ_type_P, 3), :)); @show size(ŷ_type) # (T, B)
               
    ŷ_τ = vcat([permutedims(h, [1,3,2]), 
            model.embeds[toktids["type"]](ŷ_type)]...);          @show size(ŷ_τ) # (D_m + blend, B, T)
    
    h_ = cat([model.blend_type(ŷ_τ[:, i, :]) for i in
            1:size(ŷ_τ, 2)]..., dims=3);                         @show size(h_) # (D_m, B, T)
    
    ŷ_P = [permutedims(i!=toktids["type"] ? 
            cat([proj(h[:,:,i]) for i in 1:size(h, 3)]..., dims=3) : 
            ŷ_type_P, [3,1,2]) for (proj,i)
            in zip(model.projs, 1:length(model.projs))];         @show size.(ŷ_P) # (T, N_tvoc, B)*    
    
    gen ? hcat([permutedims(sampling(permutedims(P, [2,3,1])),[3,1,2]) for P in ŷ_P]...) : ŷ_P
end

function (model::CPTransformer)(x, y; train=true)
    ŷ_P = train ? model(x, y=y) : model(x)
    println("wtf")
#     display(ŷ_P[4])
#     display(reshape(y[:,4,:], (size(y,1),1,:)))
    loss = mean([nll(ŷ_P[i], reshape(y[:,i,:], (size(y,1),1,:)), dims=2) for i in 1:length(n_tokens)])
end

In [9]:
model = CPTransformer(n_tokens, EMBED_SIZES, 512, 2048)
x, y = first(train_loader)
loss = model(x, y)
@show loss

size(x) = (3584, 7, 25)


size(x) = (1376, 3584, 25)




size(x) = (512, 25, 3584)


size(h) = (512, 25, 3584)


size(ŷ_type_P) = (3, 25, 3584)


size(ŷ_type) = (3584, 25)


size(ŷ_τ) = (544, 3584, 25)


size(h_) = (512, 25, 3584)


DimensionMismatch: DimensionMismatch("")

In [10]:
x, y = first(test_loader)
gen = model(x, gen=true)
@show summary.((y, gen))

size.(ŷ_P) = [(3584, 56, 25), (3584, 135, 25), (3584, 18, 25), (3584, 3, 25), (3584, 87, 25), (3584, 18, 25), (3584, 25, 25)]
wtf




size(x) = (3584, 7, 25)


size(x) = (1376, 3584, 25)


size(x) = (512, 25, 3584)


size(h) = (512, 25, 3584)


size(ŷ_type_P) = (3, 25, 3584)




size(ŷ_type) = (3584, 25)
size(ŷ_τ) = (544, 3584, 25)


size(h_) = (512, 25, 3584)


size.(ŷ_P) = [(3584, 56, 25), (3584, 135, 25), (3584, 18, 25), (3584, 3, 25), (3584, 87, 25), (3584, 18, 25), (3584, 25, 25)]




summary.((y, gen)) = ("3584×7×25 Array{Int64, 3}", "3584×7×25 Array{Int64, 3}")


("3584×7×25 Array{Int64, 3}", "3584×7×25 Array{Int64, 3}")

In [11]:
gen