1. train/test loss

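original code => train loss is logged from the "train"-mode forward pass inside `withgradient` (taken before each weight update), while test loss is computed in "test" mode, so the logged test loss comes out below the train loss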

see 08-04 plots

In [None]:
# Open-ended training loop (original version, kept for comparison with the corrected cell).
# Each epoch: one pass over the training columns in mini-batches with a gradient step per
# batch, then one pass over the test columns; per-epoch loss/accuracy are appended to
# train_losses / train_accuracies / test_losses / test_accuracies. Every 10th epoch the
# curves, a metrics CSV and a small params file are written to a timestamped directory.
# There is no exit condition in the loop itself, so it runs until interrupted.
# model, opt, loss, batch_size, the X_*/y_* arrays, the metric vectors and start_time
# are expected to exist already; they are not defined in this cell.
epoch = 0
while true
    global epoch += 1

    # per-epoch accumulators for the training pass
    epoch_train_losses = Float32[]
    total_train_tp = 0
    total_train_masked = 0

    # walk over the columns (2nd dimension) in chunks of batch_size; the last chunk may be shorter
    for start_idx in 1:batch_size:size(X_train_masked, 2)
        end_idx = min(start_idx + batch_size - 1, size(X_train_masked, 2))
        x_batch = gpu(X_train_masked[:, start_idx:end_idx])
        y_batch = gpu(y_train_masked[:, start_idx:end_idx])

        # The logged train loss is the value returned by the "train"-mode call inside
        # withgradient, i.e. it is computed before update! is applied for this batch.
        loss_val, grads = Flux.withgradient(model) do m # here is the difference
            loss(m, x_batch, y_batch, "train")
        end
        Flux.update!(opt, model, grads[1])

        push!(epoch_train_losses, loss_val)

        # NOTE(review): logits_masked and y_masked are never assigned inside this train
        # loop; the only visible assignment is in the test loop below. Whether these names
        # resolve at all (and to which batch) depends on scoping / leftover global state,
        # so the train accuracy computed here likely does not describe this batch - confirm.
        preds_masked = Flux.onecold(logits_masked) |> cpu
        tp = sum(preds_masked .== cpu(y_masked))
        total_masked = length(y_masked)
        total_train_tp += tp
        total_train_masked += total_masked
    end

    # log metrics
    # epoch loss is the unweighted mean of the per-batch losses
    epoch_train_loss = mean(epoch_train_losses)
    # accuracy stays 0.0 when no masked positions were counted (avoids dividing by zero)
    epoch_train_acc = 0.0
    if total_train_masked > 0
        epoch_train_acc = total_train_tp / total_train_masked
    end
    push!(train_losses, epoch_train_loss)
    push!(train_accuracies, epoch_train_acc)


    ### start test
    epoch_test_losses = Float32[]
    total_test_tp = 0
    total_test_masked = 0

    for start_idx in 1:batch_size:size(X_test_masked, 2)
        end_idx = min(start_idx + batch_size - 1, size(X_test_masked, 2))
        x_batch = gpu(X_test_masked[:, start_idx:end_idx])
        y_batch = gpu(y_test_masked[:, start_idx:end_idx])

        # in "test" mode loss returns three values: the loss, the logits and the targets
        # (presumably restricted to the masked positions - loss is defined elsewhere)
        test_loss_val, logits_masked, y_masked = loss(model, x_batch, y_batch, "test")

        push!(epoch_test_losses, test_loss_val)

        # count predictions (onecold of the logits) that equal the targets
        preds_masked = Flux.onecold(logits_masked) |> cpu
        tp = sum(preds_masked .== cpu(y_masked))
        total_masked = length(y_masked)
        total_test_tp += tp
        total_test_masked += total_masked
    end

    # log metrics
    epoch_test_loss = mean(epoch_test_losses)

    epoch_test_acc = 0.0
    if total_test_masked > 0
        epoch_test_acc = total_test_tp / total_test_masked
    end

    push!(test_losses, epoch_test_loss)
    push!(test_accuracies, epoch_test_acc)

    # checkpoint the metrics every 10 epochs
    if epoch % 10 == 0
        # NOTE(review): run_hours / run_minutes are computed but not used anywhere in this
        # cell. The division by 60000 assumes run_time.value is in milliseconds
        # (start_time being a DateTime) - TODO confirm.
        run_time = now() - start_time
        total_minutes = div(run_time.value, 60000)
        run_hours = div(total_minutes, 60)
        run_minutes = rem(total_minutes, 60)

        # one output directory per save, named by wall-clock time and epoch
        timestamp = Dates.format(now(), "yyyy-mm-dd_HH-MM-SS")
        save_dir = joinpath("plots", "untrt", "indef_masked_rankings", "$(timestamp)_epoch_$(epoch)")
        mkpath(save_dir)

        epochs_ran = 1:epoch

        # loss plot
        loss_plot = Plots.plot(epochs_ran, train_losses; label="train loss",
             xlabel="epoch", ylabel="loss", title="train vs. test loss (epoch $epoch)", lw=2)
        Plots.plot!(loss_plot, epochs_ran, test_losses; label="test loss", lw=2)
        # savefig is called without a plot argument; presumably it saves the current
        # (most recently modified) plot, which here is loss_plot
        savefig(joinpath(save_dir, "trainval_loss.png"))

        # acc plot
        acc_plot = Plots.plot(epochs_ran, train_accuracies; label="train accuracy",
            xlabel="epoch", ylabel="accuracy", title="train vs. test accuracy (epoch $epoch)", lw=2, legend=:bottomright)
        Plots.plot!(acc_plot, epochs_ran, test_accuracies; label="test accuracy", lw=2)
        savefig(joinpath(save_dir, "accuracy.png"))

        # save acc to csv
        # one row per epoch; all four metric vectors must have length `epoch`
        df = DataFrame(
            epoch = epochs_ran,
            train_accuracy = train_accuracies,
            test_accuracy = test_accuracies,
            train_loss = train_losses,
            test_loss = test_losses
            )
        CSV.write(joinpath(save_dir, "metrics.csv"), df)

        # log params
        params_txt = joinpath(save_dir, "params.txt")
        open(params_txt, "w") do io
            println(io, "EPOCH @ SAVE: $epoch")
            println(io, "TIMESTAMP: $timestamp")
            println(io, "--------------------")
        end
    end
end

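new code => loss logging corrected: after each weight update the train batch is re-evaluated in "test" mode, so train and test loss/accuracy are measured the same way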

see 08-05 plots

In [None]:
# Tally correct predictions for one batch.
# `logits_masked` / `y_masked` are the 2nd and 3rd values returned by `loss(..., "test")`.
# Returns `(n_correct, n_counted)`; `(0, 0)` when `y_masked` is empty, so an empty batch
# contributes nothing to the accuracy totals.
function _masked_hits(logits_masked, y_masked)
    isempty(y_masked) && return (0, 0)
    preds_masked = Flux.onecold(logits_masked) |> cpu
    return (sum(preds_masked .== cpu(y_masked)), length(y_masked))
end

# Open-ended training loop with corrected loss logging.
# Each epoch: (1) for every training mini-batch take a gradient step, then re-evaluate the
# same batch in "test" mode and log that loss/accuracy; (2) evaluate every test mini-batch
# in "test" mode; (3) append the epoch metrics to the history vectors; (4) every 10th
# epoch write plots, a metrics CSV and a params file to a timestamped directory.
# The loop has no exit condition and runs until interrupted.
# model, opt, loss, batch_size, the X_*/y_* arrays and the four history vectors
# (train_losses, train_accuracies, test_losses, test_accuracies) must already exist.
epoch = 0
while true
    global epoch += 1

    ### start train
    epoch_train_losses = Float32[]
    total_train_tp = 0
    total_train_masked = 0

    # walk over the columns in chunks of batch_size; the last chunk may be shorter
    for start_idx in 1:batch_size:size(X_train_masked, 2)
        end_idx = min(start_idx + batch_size - 1, size(X_train_masked, 2))
        x_batch = gpu(X_train_masked[:, start_idx:end_idx])
        y_batch = gpu(y_train_masked[:, start_idx:end_idx])

        # gradient step; the "train"-mode loss value itself is deliberately discarded
        _, grads = Flux.withgradient(model) do m
            loss(m, x_batch, y_batch, "train")
        end
        Flux.update!(opt, model, grads[1])

        # Logged train metrics come from a second, "test"-mode pass over the same batch
        # (after the update), so they are measured the same way as the test metrics.
        loss_val, logits_masked, y_masked = loss(model, x_batch, y_batch, "test")
        push!(epoch_train_losses, loss_val)

        tp, total_masked = _masked_hits(logits_masked, y_masked)
        total_train_tp += tp
        total_train_masked += total_masked
    end

    # log metrics (epoch loss = unweighted mean of the per-batch losses;
    # accuracy is 0.0 when nothing was counted, avoiding a division by zero)
    epoch_train_loss = mean(epoch_train_losses)
    epoch_train_acc = total_train_masked > 0 ? total_train_tp / total_train_masked : 0.0

    push!(train_losses, epoch_train_loss)
    push!(train_accuracies, epoch_train_acc)


    ### start test
    epoch_test_losses = Float32[]
    total_test_tp = 0
    total_test_masked = 0

    for start_idx in 1:batch_size:size(X_test_masked, 2)
        end_idx = min(start_idx + batch_size - 1, size(X_test_masked, 2))
        x_batch = gpu(X_test_masked[:, start_idx:end_idx])
        y_batch = gpu(y_test_masked[:, start_idx:end_idx])

        test_loss_val, logits_masked, y_masked = loss(model, x_batch, y_batch, "test")
        push!(epoch_test_losses, test_loss_val)

        tp, total_masked = _masked_hits(logits_masked, y_masked)
        total_test_tp += tp
        total_test_masked += total_masked
    end

    # log metrics
    epoch_test_loss = mean(epoch_test_losses)
    epoch_test_acc = total_test_masked > 0 ? total_test_tp / total_test_masked : 0.0

    push!(test_losses, epoch_test_loss)
    push!(test_accuracies, epoch_test_acc)

    # checkpoint the metrics every 10 epochs
    if epoch % 10 == 0
        # one output directory per save, named by wall-clock time and epoch
        timestamp = Dates.format(now(), "yyyy-mm-dd_HH-MM-SS")
        save_dir = joinpath("plots", "untrt", "indef_masked_new", "$(timestamp)_epoch_$(epoch)")
        mkpath(save_dir)

        epochs_ran = 1:epoch

        # loss plot
        loss_plot = Plots.plot(epochs_ran, train_losses; label="train loss",
             xlabel="epoch", ylabel="loss", title="train vs. test loss (epoch $epoch)", lw=2)
        Plots.plot!(loss_plot, epochs_ran, test_losses; label="test loss", lw=2)
        # pass the plot handle explicitly instead of relying on the implicit "current plot"
        Plots.savefig(loss_plot, joinpath(save_dir, "trainval_loss.png"))

        # acc plot
        acc_plot = Plots.plot(epochs_ran, train_accuracies; label="train accuracy",
            xlabel="epoch", ylabel="accuracy", title="train vs. test accuracy (epoch $epoch)", lw=2, legend=:bottomright)
        Plots.plot!(acc_plot, epochs_ran, test_accuracies; label="test accuracy", lw=2)
        Plots.savefig(acc_plot, joinpath(save_dir, "accuracy.png"))

        # save acc to csv (one row per epoch)
        df = DataFrame(
            epoch = epochs_ran,
            train_accuracy = train_accuracies,
            test_accuracy = test_accuracies,
            train_loss = train_losses,
            test_loss = test_losses
            )
        CSV.write(joinpath(save_dir, "metrics.csv"), df)

        # log params
        params_txt = joinpath(save_dir, "params.txt")
        open(params_txt, "w") do io
            println(io, "EPOCH @ SAVE: $epoch")
            println(io, "TIMESTAMP: $timestamp")
        end
    end
end

2. re-typing model structure

previous code:

In [6]:
using Flux, CUDA, JLD2

In [4]:
# so we can use GPU or CPU :D
# Each alias is a Union of a CPU `Array` type and a GPU `CuArray` type.
# IntMatrix2DType and Float32Matrix3DType annotate the forward-pass arguments of the
# layers defined below.
# NOTE(review): `Array{Int}` and `Array{Float32}` carry no dimension parameter, so the CPU
# side of the two "Matrix2D" aliases accepts arrays of any rank while the CuArray side is
# restricted to 2 dimensions - confirm whether `Array{Int, 2}` / `Array{Float32, 2}` was meant.
const IntMatrix2DType = Union{Array{Int}, CuArray{Int, 2}}
const Float32Matrix2DType = Union{Array{Float32}, CuArray{Float32, 2}}
const Float32Matrix3DType = Union{Array{Float32, 3}, CuArray{Float32, 3}}

### positional encoder

# Sinusoidal positional-encoding table, stored as a GPU array (embed_dim × max_len).
struct PosEnc
    pe_matrix::CuArray{Float32,2}
end

# Build the table on the CPU and move it to the GPU with `cu`.
# max_len is usually the maximum sequence length, but here it is just len(genes).
# Rows alternate sin (odd row index) and cos (even row index); each pair of rows
# (2k-1, 2k) shares the frequency 1 / 10000^(2(k-1)/embed_dim).
function PosEnc(embed_dim::Int, max_len::Int)
    entry(i, pos) = begin
        angle = pos / (10000^(2 * div(i - 1, 2) / embed_dim))
        isodd(i) ? sin(angle) : cos(angle)
    end
    table = Float32[entry(i, pos) for i in 1:embed_dim, pos in 1:max_len]
    return PosEnc(cu(table))
end

Flux.@functor PosEnc

# Forward pass: add the first size(input, 2) columns of the encoding table to the
# input embeddings; the 2-D slice is broadcast across the 3rd dimension of `input`.
function (pe::PosEnc)(input::Float32Matrix3DType)
    n_positions = size(input, 2)
    encoding = pe.pe_matrix[:, 1:n_positions]
    return input .+ encoding
end

### building transformer section

# One transformer block: self-attention and an MLP, each with its own LayerNorm.
struct Transf
    mha::Flux.MultiHeadAttention
    att_dropout::Flux.Dropout
    att_norm::Flux.LayerNorm # normalization applied before attention
    mlp::Flux.Chain
    mlp_norm::Flux.LayerNorm
end

# Keyword constructor. All attention projections keep width `embed_dim`; the MLP
# expands to `hidden_dim` (gelu) and projects back, with dropout after each Dense.
# The attention layer is created before the MLP, as in the original ordering.
function Transf(embed_dim::Int, hidden_dim::Int; n_heads::Int, dropout_prob::Float64)
    attention_dims = (embed_dim, embed_dim, embed_dim) => (embed_dim, embed_dim) => embed_dim
    attention = Flux.MultiHeadAttention(attention_dims;
                                        nheads=n_heads,
                                        dropout_prob=dropout_prob)

    feed_forward = Flux.Chain(
        Flux.Dense(embed_dim => hidden_dim, gelu),
        Flux.Dropout(dropout_prob),
        Flux.Dense(hidden_dim => embed_dim),
        Flux.Dropout(dropout_prob),
    )

    return Transf(
        attention,
        Flux.Dropout(dropout_prob),
        Flux.LayerNorm(embed_dim),
        feed_forward,
        Flux.LayerNorm(embed_dim),
    )
end

Flux.@functor Transf

# Forward pass of one block; input shape: embed_dim × seq_len × batch_size.
# Both sub-layers normalise first and add their output back onto the un-normalised input.
function (tf::Transf)(input::Float32Matrix3DType)
    # self-attention sub-layer (mha returns a tuple; only its first element is used)
    attn_in = tf.att_norm(input)
    attn_out, _ = tf.mha(attn_in, attn_in, attn_in)
    after_attn = input + tf.att_dropout(attn_out)

    # MLP sub-layer: flatten seq and batch into one axis for the dense layers, then restore
    ff_in = tf.mlp_norm(after_attn)
    d, len, nbatch = size(ff_in)
    ff_out = tf.mlp(reshape(ff_in, d, len * nbatch))

    return after_attn + reshape(ff_out, d, len, nbatch)
end

### full model as << ranked data --> token embedding --> position embedding --> transformer --> classifier head >>

# Full model: token embedding -> positional encoding -> dropout -> transformer stack
# -> classifier head.
struct Model
    embedding::Flux.Embedding
    pos_encoder::PosEnc
    pos_dropout::Flux.Dropout
    transformer::Flux.Chain
    classifier::Flux.Chain
end

# Keyword-only constructor. `input_size` is used both as the embedding vocabulary size
# and as the length of the positional-encoding table; `n_layers` Transf blocks are chained.
function Model(;
    input_size::Int,
    embed_dim::Int,
    n_layers::Int,
    n_classes::Int,
    n_heads::Int,
    hidden_dim::Int,
    dropout_prob::Float64
    )

    token_embedding = Flux.Embedding(input_size => embed_dim)
    positions = PosEnc(embed_dim, input_size)
    input_dropout = Flux.Dropout(dropout_prob)

    blocks = (
        Transf(embed_dim, hidden_dim; n_heads=n_heads, dropout_prob=dropout_prob)
        for _ in 1:n_layers
    )
    stack = Flux.Chain(blocks...)

    head = Flux.Chain(
        Flux.Dense(embed_dim => embed_dim, gelu),
        Flux.LayerNorm(embed_dim),
        Flux.Dense(embed_dim => n_classes),
    )

    return Model(token_embedding, positions, input_dropout, stack, head)
end

Flux.@functor Model

# Forward pass: run the integer input through every stage in order and return the logits.
# No pooling is applied; the classifier receives the transformer output directly.
function (model::Model)(input::IntMatrix2DType)
    return input |>
        model.embedding |>
        model.pos_encoder |>
        model.pos_dropout |>
        model.transformer |>
        model.classifier
end

new code:

In [None]:
# Sinusoidal positional-encoding table; parametric in the matrix type so the same
# struct can hold a CPU or a GPU array.
struct PosEnc1{U<:AbstractMatrix} #!#
    pe_matrix::U
end

# Build the embed_dim × max_len table as a plain Matrix{Float32} (no device transfer here).
# max_len is usually the maximum sequence length, but here it is just len(genes).
# Rows alternate sin (odd row index) and cos (even row index); each pair of rows
# (2k-1, 2k) shares the frequency 1 / 10000^(2(k-1)/embed_dim).
function PosEnc1(embed_dim::Int, max_len::Int)
    entry(i, pos) = begin
        angle = pos / (10000^(2 * div(i - 1, 2) / embed_dim))
        isodd(i) ? sin(angle) : cos(angle)
    end
    table = Float32[entry(i, pos) for i in 1:embed_dim, pos in 1:max_len]
    return PosEnc1(table)
end

Flux.@functor PosEnc1

# Forward pass: add the first size(input, 2) columns of the encoding table to the
# input embeddings; the 2-D slice is broadcast across the 3rd dimension of `input`.
function (pe::PosEnc1)(input::Float32Matrix3DType)
    n_positions = size(input, 2)
    encoding = pe.pe_matrix[:, 1:n_positions]
    return input .+ encoding
end

### building transformer section

# One transformer block with concretely-typed (parametric) fields.
# Both LayerNorm fields share the type parameter LN, so they must have the same
# concrete type; the constructor below builds both with `Flux.LayerNorm(embed_dim)`.
struct Transf1{MHA<:Flux.MultiHeadAttention, D<:Flux.Dropout, LN<:Flux.LayerNorm, C<:Flux.Chain} #!#
    mha::MHA
    att_dropout::D
    att_norm::LN
    mlp::C
    mlp_norm::LN
end

# Keyword constructor.
#   embed_dim    - width of the attention projections and of the block's input/output
#   hidden_dim   - width of the MLP's hidden layer
#   n_heads      - number of attention heads
#   dropout_prob - dropout probability used in attention, after attention, and in the MLP
# Returns a `Transf1`. The original returned the old `Transf` type here, so the
# parametric struct was never actually instantiated.
function Transf1(
    embed_dim::Int, 
    hidden_dim::Int; 
    n_heads::Int, 
    dropout_prob::Float64
    )

    mha = Flux.MultiHeadAttention((embed_dim, embed_dim, embed_dim) => (embed_dim, embed_dim) => embed_dim, 
                                    nheads=n_heads, 
                                    dropout_prob=dropout_prob
                                    )

    att_dropout = Flux.Dropout(dropout_prob)
    
    att_norm = Flux.LayerNorm(embed_dim)
    
    # expand to hidden_dim with gelu, project back to embed_dim; dropout after each Dense
    mlp = Flux.Chain(
        Flux.Dense(embed_dim => hidden_dim, gelu),
        Flux.Dropout(dropout_prob),
        Flux.Dense(hidden_dim => embed_dim),
        Flux.Dropout(dropout_prob)
        )
    mlp_norm = Flux.LayerNorm(embed_dim)

    return Transf1(mha, att_dropout, att_norm, mlp, mlp_norm)
end

Flux.@functor Transf1

# Forward pass of one block; input shape: embed_dim × seq_len × batch_size.
# Both sub-layers normalise first and add their output back onto the un-normalised input.
function (tf::Transf1)(input::Float32Matrix3DType)
    # self-attention sub-layer (mha returns a tuple; only its first element is used)
    attn_in = tf.att_norm(input)
    attn_out = tf.mha(attn_in, attn_in, attn_in)[1]
    after_attn = input + tf.att_dropout(attn_out)

    # MLP sub-layer: flatten seq and batch into one axis for the dense layers, then restore
    ff_in = tf.mlp_norm(after_attn)
    d, len, nbatch = size(ff_in)
    ff_out = tf.mlp(reshape(ff_in, d, len * nbatch))

    return after_attn + reshape(ff_out, d, len, nbatch)
end

# Full model with concretely-typed (parametric) fields:
# token embedding -> positional encoding -> dropout -> transformer stack -> classifier head.
# The positional encoder is bound to the re-typed `PosEnc1` (the original bound it to
# the old `PosEnc`, whose field is hard-coded to a CuArray).
struct Model1{E<:Flux.Embedding, P<:PosEnc1, D<:Flux.Dropout, T<:Flux.Chain, C<:Flux.Chain} #!#
    embedding::E
    pos_encoder::P
    pos_dropout::D
    transformer::T
    classifier::C
end

# Keyword-only constructor.
#   input_size   - embedding vocabulary size; also the length of the positional table
#   embed_dim    - embedding / model width
#   n_layers     - number of Transf1 blocks chained in the transformer
#   n_classes    - output width of the classifier head
#   n_heads      - attention heads per block
#   hidden_dim   - MLP hidden width per block
#   dropout_prob - dropout probability used throughout
# Builds the model from the re-typed components (`PosEnc1`, `Transf1`). The positional
# table is created as a CPU matrix here; presumably the whole model is moved to the
# device afterwards (e.g. with `gpu`) - confirm in the caller.
function Model1(;
    input_size::Int,
    embed_dim::Int,
    n_layers::Int,
    n_classes::Int,
    n_heads::Int,
    hidden_dim::Int,
    dropout_prob::Float64
    )

    embedding = Flux.Embedding(input_size => embed_dim)

    pos_encoder = PosEnc1(embed_dim, input_size)

    pos_dropout = Flux.Dropout(dropout_prob)

    # splat a generator so the Chain holds a tuple of blocks
    transformer = Flux.Chain(
        (Transf1(embed_dim, hidden_dim; n_heads, dropout_prob) for _ in 1:n_layers)...
        )

    classifier = Flux.Chain(
        Flux.Dense(embed_dim => embed_dim, gelu),
        Flux.LayerNorm(embed_dim),
        Flux.Dense(embed_dim => n_classes)
        )

    return Model1(embedding, pos_encoder, pos_dropout, transformer, classifier)
end

Flux.@functor Model1

# Forward pass: embed the integer input, add positional encodings, apply dropout,
# run the transformer stack and return the classifier logits (no pooling step).
function (model::Model1)(input::IntMatrix2DType)
    with_positions = model.pos_encoder(model.embedding(input))
    hidden = model.transformer(model.pos_dropout(with_positions))
    return model.classifier(hidden)
end

see testing.jl

3. comparison of inputs => new graphs

scatter plot - binning vs. no binning

heatmap - vs. box scatter