<< to keep track of changes made, prev. ideas, etc. >>

apr 8, 2025 - flux_transf.jl
-

this is what was used for the original 2D input into mhsa (which doesn't work; it needs (q, k, v) input)

In [None]:
function (tf::Transf)(input::Float32Matrix2DType) # input is features (978) x batch
    sa_out = tf.mhsa(tf.norm_1(input)) # OG paper states norm after sa, but norm before sa is more common?
    # x = input + tf.dropout(sa_out)
    x = input + sa_out
    mlp_out = tf.mlp(tf.norm_2(x))
    # x = x + tf.dropout(mlp_out)
    x = x + mlp_out
    return x
end

apr 11, 2025 - flux_transf.jl
-

using a for loop that fills a preallocated matrix of -1s in place, rather than making a copy, is much faster!
the commented-out version in the function below is a little easier to read, but the current impl. in the file is much more efficient.
tmp: this version replaces the -1s with the ranking ints themselves rather than the gene names (Str)

In [None]:
function sort_gene(expr)
    # data_ranked = Matrix{Int}(undef, size(expr))
    data_ranked = fill(-1, size(expr))
    # gs = Symbol.(gene_symbols)
    n, m = size(expr)
    p = Vector{Int}(undef, n)
    # tmp = sortperm(expr[:, 1])

    for j in 1:m
        e = view(expr, :, j)
        sortperm!(p, e, rev=true)
        # data_ranked[!, j] = gs[tmp]

        for i in 1:n
            data_ranked[i, j] = p[i]
        end
    end
    return data_ranked
end
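
quick check of what the function above produces, as a minimal sketch on a made-up 3 genes × 2 samples matrix (`rank_by_expression` and the toy values are hypothetical, not taken from flux_transf.jl). each column ends up holding gene indices ordered from highest to lowest expression:

```julia
# toy version of the in-place ranking: fill a preallocated matrix column by column
function rank_by_expression(expr::AbstractMatrix)
    n, m = size(expr)
    ranked = fill(-1, n, m)          # -1 marks "not filled yet"
    p = Vector{Int}(undef, n)        # permutation buffer, reused for every column
    for j in 1:m
        sortperm!(p, view(expr, :, j); rev=true)  # gene indices, highest expression first
        ranked[:, j] .= p
    end
    return ranked
end

expr = [0.1 5.0;
        3.0 1.0;
        2.0 9.0]                     # rows = genes, cols = samples (made-up values)
ranked = rank_by_expression(expr)
# column 1 is [2, 3, 1]: gene 2 is highest in sample 1, then gene 3, then gene 1
```

note that each entry is the index of the gene sitting at that rank, not the rank of the gene at that row.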

GENES AS TOKENS: gene-gene interactions (sequence length = len(genes))
- each gene is a position in your sequence
- each token's embedding contains information about that gene's ranking across samples
- the sequence length equals the number of genes you're considering

SAMPLES AS TOKENS: sample-sample interactions (sequence length = len(samples))
- each sample is a position in your sequence
- each token's embedding contains the gene ranking information for that sample
- the sequence length equals the number of samples

TECHNICAL CONSIDERATIONS
- transformers struggle with very long sequences, so if we have many more genes than samples, using samples as tokens may be more computationally feasible
- self-attention complexity grows quadratically with sequence length
- which dimension has more examples to learn from?
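
to put numbers on the quadratic cost, a rough sketch that counts the entries of one attention score matrix (per head, per batch element) for each choice of token. `n_genes = 978` comes from the input size used in these notes; `n_samples` is a hypothetical count chosen only for illustration:

```julia
# entries in one attention score matrix: seq_len × seq_len
attention_entries(seq_len::Int) = seq_len^2

n_genes = 978        # features per sample, as in these notes
n_samples = 100_000  # hypothetical sample count, only for illustration

genes_as_tokens = attention_entries(n_genes)      # 956_484
samples_as_tokens = attention_entries(n_samples)  # 10_000_000_000
ratio = samples_as_tokens / genes_as_tokens       # how much larger the sample-token matrix is
```

whichever dimension is shorter gives the cheaper attention matrix, so the choice flips depending on whether there are more genes or more samples.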

good option: Flux.MultiHeadAttention((64, 64, 64) => (64, 64) => 64, nheads=1), can increase nheads later
- q, k, v input dims should all be the same if the data type is the same or we aren't doing encoder-decoder
- middle dimensions should also be the same unless we want to reduce computational complexity in the middle
- output dim can also be the same unless we want to do feature compression or expansion
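
to see what the (q, k, v) inputs are doing, here is a minimal single-head scaled dot-product attention in plain Julia for one sample. this is a sketch only: no learned projections, no batching, no Flux, and `softmax_cols` / `single_head_attention` are hypothetical helper names, not the Flux implementation. for self-attention the same array is simply passed three times:

```julia
# column-wise softmax: each column (one query) sums to 1
function softmax_cols(A::AbstractMatrix)
    E = exp.(A .- maximum(A; dims=1))   # subtract column max for numerical stability
    return E ./ sum(E; dims=1)
end

# q, k, v: (dim, seq_len) matrices for a single sample
function single_head_attention(q, k, v)
    d = size(q, 1)
    scores = (k' * q) ./ sqrt(Float32(d))  # (kv_len, q_len): key i scored against query j
    α = softmax_cols(scores)               # attention weights over keys, per query
    return v * α, α                        # output (dim, q_len) and weights (kv_len, q_len)
end

x = rand(Float32, 64, 10)                  # 64-dim embeddings, sequence of 10 tokens
y, α = single_head_attention(x, x, x)      # self-attention: q = k = v = x
```

the output keeps the input shape (64 × 10), and the weight matrix is seq_len × seq_len, which is where the quadratic cost comes from.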

may 1, 2025 - masked loss fxn - DONE
- 

In [None]:
function loss_masked(model, x, y_masked)
    logits = model(x)  # (n_classes, seq_len, batch_size)
    logits = permutedims(logits, (2, 3, 1))  # seq_len × batch_size × n_classes (to match targets)
    logits = reshape(logits, :, n_classes)   # (seq_len * batch_size) × n_classes

    y_masked_flat = vec(y_masked) # flatten
    # only keep where y_masked != -100
    mask = y_masked_flat .!= -100
    logits_masked = (logits[mask, :])'
    targets_masked = y_masked_flat[mask]
    y_oh = Flux.onehotbatch(targets_masked, 1:n_classes)

    return Flux.logitcrossentropy(logits_masked, y_oh)
end

may 27, 2025 - training fxn with masked values for accuracy - DONE
- 

In [None]:
function loss(model, x, y)
    logits = model(x)  # (n_classes, seq_len, batch_size)
    logits_flat = reshape(logits, size(logits, 1), :) # (n_classes, seq_len*batch_size)
    y_flat = vec(y) # (seq_len*batch_size) column vec

    mask = y_flat .!= -100 # bit vec, where sum = n_masked
    logits_masked = logits_flat[:, mask] # (n_classes, n_masked)
    y_masked = y_flat[mask] # (n_masked) column vec

    y_oh = Flux.onehotbatch(y_masked, 1:n_classes) # (n_classes, n_masked)
    return Flux.logitcrossentropy(logits_masked, y_oh) 
end
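
for sanity-checking the loss above without a model, a plain-Julia sketch of the same computation (log-softmax over classes, keep only positions where the label is not -100, average the negative log-probability of the true class). `masked_ce` and `logsoftmax_cols` are hypothetical names, and this assumes at least one position is unmasked:

```julia
# log-softmax over the class dimension (rows)
function logsoftmax_cols(A::AbstractMatrix)
    shifted = A .- maximum(A; dims=1)
    return shifted .- log.(sum(exp.(shifted); dims=1))
end

# logits: (n_classes, seq_len, batch_size); y: (seq_len, batch_size) with `ignore` at unmasked positions
function masked_ce(logits, y; ignore=-100)
    n_classes = size(logits, 1)
    logits_flat = reshape(logits, n_classes, :)   # (n_classes, seq_len * batch_size)
    y_flat = vec(y)                               # same column-major order as logits_flat
    keep = findall(!=(ignore), y_flat)            # positions that were masked in the input
    logp = logsoftmax_cols(logits_flat[:, keep])  # (n_classes, n_masked)
    # mean negative log-probability of the true class (errors if keep is empty)
    return -sum(logp[y_flat[k], j] for (j, k) in enumerate(keep)) / length(keep)
end

logits = zeros(Float32, 4, 3, 2)    # uniform prediction over 4 classes
y = [-100 2; 3 -100; -100 -100]     # two labelled positions, the rest ignored
uniform_loss = masked_ce(logits, y) # equals log(4) for a uniform prediction
```

a uniform prediction over `n_classes` gives a loss of `log(n_classes)`, which is a handy baseline when reading the training curves.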

could return logits_masked and y_masked as well, then do:

In [None]:
preds_masked = Flux.onecold(logits_masked)
preds_masked_cpu = preds_masked |> cpu
preds_masked_cpu .== y_masked
accuracy = sum(preds_masked_cpu .== y_masked) / length(y_masked)

may 29, 2025 - sparse matrices
- 

In [None]:
# in train loop, instead of y_gpu:
y_batch_sparse = get_sparse_batch(y_train_masked, start_idx, end_idx)

using SparseArrays
using Random  # for randperm

function mask_input_sparse(X::Matrix{Int64}; mask_ratio=0.10)
    X_masked = copy(X)
    # Create sparse matrix for labels
    I_indices = Int[]  # row indices
    J_indices = Int[]  # column indices  
    values = Int16[]   # actual gene indices
    
    for j in 1:size(X, 2)
        num_masked = ceil(Int, size(X, 1) * mask_ratio)
        mask_positions = randperm(size(X, 1))[1:num_masked]
        
        for pos in mask_positions
            push!(I_indices, pos)
            push!(J_indices, j)
            push!(values, X[pos, j])  # original gene index
            
            X_masked[pos, j] = MASK_ID
        end
    end
    
    y_sparse = sparse(I_indices, J_indices, values, size(X)...)
    return X_masked, y_sparse
end

X_train_masked, y_train_masked = mask_input_sparse(X_train)
X_test_masked, y_test_masked = mask_input_sparse(X_test)

function loss_sparse(model, x, y_sparse_batch, mode)
    logits = model(x)  # (n_classes, seq_len, batch_size)
    
    rows_cpu, cols_cpu, vals_cpu = findnz(y_sparse_batch)
    
    if isempty(rows_cpu)
        return 0.0f0
    end
    
    rows_gpu = cu(rows_cpu)
    cols_gpu = cu(cols_cpu) 
    vals_gpu = cu(vals_cpu)
    
    batch_size = size(logits, 3)
    seq_len = size(logits, 2)
    
    linear_indices = (cols_gpu .- 1) .* seq_len .+ rows_gpu
    
    logits_reshaped = reshape(logits, size(logits, 1), :) # (n_classes, seq_len * batch_size)
    masked_logits = logits_reshaped[:, linear_indices]  # (n_classes, n_masked)
    
    y_oh = Flux.onehotbatch(vals_gpu, 1:n_classes)
    
    if mode == "train"
        return Flux.logitcrossentropy(masked_logits, y_oh)
    elseif mode == "test"
        return Flux.logitcrossentropy(masked_logits, y_oh), masked_logits, vals_gpu
    end
end

function get_sparse_batch(y_sparse, start_idx, end_idx)
    # findnz on the column slice already returns batch-relative column indices (1:batch_size),
    # so no column adjustment is needed
    rows, cols, vals = findnz(y_sparse[:, start_idx:end_idx])
    batch_sparse = sparse(rows, cols, vals, size(y_sparse, 1), end_idx - start_idx + 1)
    
    return batch_sparse
end
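
a small CPU-only check of the index bookkeeping used above, as a sketch on made-up values: `findnz` returns (row, col, value) triplets, `(col - 1) * seq_len + row` is the column-major linear index into the flattened (seq_len * batch_size) axis, and slicing columns gives batch-relative column indices:

```julia
using SparseArrays

seq_len, batch_size = 4, 3
# made-up labels: (row, col) => original token id
y_sparse = sparse([2, 4, 1], [1, 1, 3], Int16[7, 9, 5], seq_len, batch_size)

rows, cols, vals = findnz(y_sparse)
linear_idx = (cols .- 1) .* seq_len .+ rows       # same formula as in loss_sparse

# cross-check against Julia's own column-major indexing
expected = LinearIndices((seq_len, batch_size))[CartesianIndex.(rows, cols)]

# slicing columns 2:3 re-bases the column index: the entry at (1, 3) shows up as (1, 2)
rows_b, cols_b, vals_b = findnz(y_sparse[:, 2:3])
```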

** with the above code: 11 min per epoch (see github @ this time for other params) vs. 11 min per epoch for dense representations, so no speedup.
can revisit later if memory bottlenecks show up in the matrix operations; for now dense is sufficient b/c:
https://www.reddit.com/r/Julia/comments/108g5ou/when_is_it_worth_working_with_sparse_matrices/
https://medium.com/data-science/sparse-matrices-in-pytorch-part-2-gpus-fd9cc0725b71
- the above state that GPU sparse matrices only become efficient at about 1% nonzero entries or less (i.e. ~99%+ sparsity); the masked labels here are ~10% nonzero
- the sparse version also pays for repeated CPU <-> GPU transfers and for sparse slicing when batching
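
the fraction of nonzero labels follows directly from the masking settings; a quick arithmetic sketch using the 978-gene input and the default `mask_ratio=0.10` from `mask_input_sparse`, compared against the roughly 1% figure quoted above:

```julia
seq_len = 978                                  # genes per sample, as in these notes
mask_ratio = 0.10                              # default in mask_input_sparse
num_masked = ceil(Int, seq_len * mask_ratio)   # masked positions per sample
density = num_masked / seq_len                 # fraction of nonzero entries in the label matrix
above_threshold = density > 0.01               # true: about 10x denser than the ~1% guideline
```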

june 12, 2025 - trying dynamic masking
- 

**RoBERTa shows that masking a different subset every epoch already helps,
but recent work finds that decreasing the masking rate during training is even better
(to try later!!! aka a mask-rate scheduler)

***dynamic on the train, static on the test
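
a minimal sketch of that setup: the test mask is drawn once with a fixed seed (static), the train mask is re-drawn every epoch (dynamic), and the masking rate can follow a schedule. everything here is an assumption for illustration: `mask_columns`, `mask_ratio_at`, the linear decay, the 0.15 to 0.05 range, and `mask_id=11` are hypothetical, not the code in the repo:

```julia
using Random

# mask `mask_ratio` of the positions in every column; labels are -100 where nothing was masked
function mask_columns(X::Matrix{Int}; mask_ratio=0.10, mask_id, rng=Random.default_rng())
    X_masked = copy(X)
    y = fill(-100, size(X))
    n = size(X, 1)
    num_masked = ceil(Int, n * mask_ratio)
    for j in axes(X, 2)
        for pos in randperm(rng, n)[1:num_masked]
            y[pos, j] = X[pos, j]        # remember the original token
            X_masked[pos, j] = mask_id   # hide it in the input
        end
    end
    return X_masked, y
end

# hypothetical schedule: linear decay of the masking rate over training
mask_ratio_at(epoch, n_epochs; start=0.15, stop=0.05) =
    start + (stop - start) * (epoch - 1) / max(n_epochs - 1, 1)

X = repeat(collect(1:10), 1, 3)          # toy data: 10 positions × 3 samples
# static test mask: drawn once, with a fixed seed
X_test_masked, y_test = mask_columns(X; mask_ratio=0.5, mask_id=11, rng=MersenneTwister(42))

for epoch in 1:3
    ratio = mask_ratio_at(epoch, 3)
    # dynamic train mask: a fresh draw every epoch
    X_train_masked, y_train = mask_columns(X; mask_ratio=ratio, mask_id=11)
    # ... train on (X_train_masked, y_train) here ...
end
```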

jun 16, 2025 - tried to only mask 1 position; if it can't predict just the missing # from 1-978, then it's dumb.
- 

- the issue here might be that the test mask is different from the train mask; 
thus w/o dynamic masking, the train learns a single value each time, and the test provides a new value not seen before..?

also trying to profile the memory b/c 25min per epoch is way too long;
Profile.Allocs.@profile sample_rate=1 begin/end (which records every single allocation) is taking wayyyyy too long too (~3hrs so far) to run in the REPL..
maybe there is another way?
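
one cheaper option to try, as a sketch only: measure a single step with `@allocated` first, and if the allocation profile is still needed, sample a small fraction of allocations instead of all of them. `fake_step` is a hypothetical stand-in for one training step, and the 0.01 rate is an arbitrary choice:

```julia
using Profile

# hypothetical stand-in for one training step (allocates a couple of matrices)
function fake_step()
    x = rand(Float32, 256, 256)
    return sum(x * x)
end

fake_step()                                   # warm up so compilation is not measured
bytes = @allocated fake_step()                # total bytes allocated by one step

Profile.Allocs.clear()
Profile.Allocs.@profile sample_rate=0.01 fake_step()   # record ~1% of allocations
results = Profile.Allocs.fetch()              # sampled allocations with types and stack traces
```

profiling one or a few steps rather than a whole epoch should already show which calls allocate the most.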

*also what takes 26h on kraken takes 4h on smaug...
there seems to be slightly better train/val loss using a higher embed dim --> try higher dim

jun 20, 2025 - based on the results from smaug, 2025-06-19_21-48
- 

there seems to be an issue with learning different masked tokens (1 per sample) across the whole dataset
it works if we use the same, say, 5 masks across the whole dataset (albeit at only an 84% accuracy for some reason)
1. double check masking function - ensure that it is correct
2. scale up learning rate..? not sure what else to do here
the training masks have sufficient examples to learn from i think (60 per label) so what is going wrong?


jul 24, 2025 - speed/memory optimization
-

- lux.jl is roughly the julia counterpart of jax/flax in python (explicit params; can compile to an xla backend via Reactant.jl)
- unrelated, but what about a split data struct..? (what lea uses)

jul 25, 2025 - exp masking...?
- 

1. can you embed counts? how would that work
2. needs to be a regression output (1 value) rather than a vector of probabilities per class?
3. loss needs to be defined differently
4. should it just be a dense network? does MHA work on 

jul 30, 2025 - exp masking
- 

TODO:
- Dense layer instead of Flux.Embedding 
    - this is b/c of static vs. dynamic embeddings (see below)
    - dense layer stores f(x) = Wx + b, where W and b are learned during training for all genes/samples
    - W is a matrix and b is a vector: W is multiplied by the input x and b is added, giving f(x), the embedding vector
- MHA still works with non-tokenized values!
    - model learns the meaning of the expression levels rather than the gene identity/ranks?
- mask token should be 0.0 after normalizing input
- MSE loss
- regression output to 1 value
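
for the MSE part, a minimal sketch of a masked regression loss in plain Julia: only the positions that were masked in the input contribute. `masked_mse` and the toy numbers are hypothetical, and this assumes the model output has the same shape as the targets:

```julia
# preds, targets: same shape (e.g. seq_len × batch_size); mask: Bool array, true where the input was masked
function masked_mse(preds, targets, mask)
    diffs = preds[mask] .- targets[mask]      # keep masked positions only
    return sum(abs2, diffs) / length(diffs)   # mean squared error over those positions
end

preds   = Float32[1 2; 3 4]
targets = Float32[1 0; 0 4]
mask    = Bool[0 1; 1 0]                      # two masked positions
loss_value = masked_mse(preds, targets, mask) # (3^2 + 2^2) / 2 = 6.5
```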

Static Embeddings (like Word2Vec or Flux.Embedding):
- The model learns one single, fixed vector for each unique token (e.g., the word "bank").
- The goal is to learn the meaning of the token itself by averaging its usage across thousands of different contexts.

Dynamic Feature Representation (Your Model):
- Your model does not learn a static vector for "Gene 1".
- Instead, it learns a function that maps any given expression value to a vector representation.
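
the difference can be seen in a few lines of plain Julia, as a sketch with random (untrained) weights: a static embedding is a table lookup by token id, while the dense version computes W*x + b from the expression value itself. `lookup`, `project`, and the sizes are hypothetical names and values for illustration:

```julia
embed_dim, n_tokens = 4, 6

# static: one column per token id, selected by index
table = randn(Float32, embed_dim, n_tokens)
lookup(ids) = table[:, ids]                     # (embed_dim, length(ids))

# dynamic: a function of the expression value, f(x) = W*x + b
W = randn(Float32, embed_dim, 1)
b = zeros(Float32, embed_dim)
project(x) = W * reshape(x, 1, :) .+ b          # x: vector of expression values

same_token = lookup([3, 3])                     # identical columns: the id fully decides the vector
two_values = project(Float32[0.5, 2.0])         # different columns: the value decides the vector
```

with the lookup, two occurrences of token 3 always get the same vector; with the projection, the vector changes continuously with the expression value.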

9PM - edit to exp masking
- had to move loss calc to before model update - resulted in lower loss in test than train
- change masking val to -1, apparently there are exp levels of 0 in the original dataset?

***should make a plot similar to this scatter plot for check_error.jl file


aug 1, 2025 - pre lab-meeting
- 

***need to fix the logging of params for predstrues.csv and params.txt ; didn't save in the last couple runs

aug 4, 2025 - debug + seb loss issue
-

- running on GPU 2 for indef run - rerun for seb
- need to 
    1. figure out logging for test/train - is it diff than what's in the slides?
    2. debug original structure typing
    3. fix logging of params for predstrues.csv and params.txt when doing input comparison plots
    4. fix progressbar for indef_run.jl (why is it out of 628, repeats for each epoch???)

- 08-04 run is original test/loss definitions from creating 720ep graph
- changed code to use mask_transf code in the while loop
- so Flux.withgradient = 1st loss calc --> update! --> second loss calc
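
why the ordering matters, as a toy sketch: the loss that comes back together with the gradient is evaluated before the parameter update, so a loss recomputed after the update is lower on the same batch. below, a one-parameter model with a hand-written gradient stands in for `Flux.withgradient` and `update!` (all names and values are hypothetical):

```julia
# toy model: prediction = w * x, loss = mean squared error
loss(w, x, y) = sum(abs2, w .* x .- y) / length(x)
grad(w, x, y) = 2 * sum((w .* x .- y) .* x) / length(x)   # d(loss)/dw

x = [1.0, 2.0, 3.0]
y = 2.0 .* x                       # true w is 2
w = 0.0
lr = 0.05

loss_before = loss(w, x, y)        # 1st loss calc: what comes back with the gradient (pre-update)
w -= lr * grad(w, x, y)            # the update step
loss_after = loss(w, x, y)         # 2nd loss calc: same batch, post-update weights
```

if train loss is logged pre-update and test loss is computed post-update, test can look better than train for that reason alone.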

aug 5, 2025 - debug
-

original structure:

In [None]:
struct PosEnc
    pe_matrix::CuArray{Float32,2}
end

function PosEnc(embed_dim::Int, max_len::Int) # max_len is usually maximum length of sequence but here it is just len(genes)
    pe_matrix = Matrix{Float32}(undef, embed_dim, max_len)
    for pos in 1:max_len, i in 1:embed_dim
        angle = pos / (10000^(2*(div(i-1,2))/embed_dim))
        if mod(i, 2) == 1
            pe_matrix[i,pos] = sin(angle) # odd indices
        else
            pe_matrix[i,pos] = cos(angle) # even indices
        end
    end
    return PosEnc(cu(pe_matrix))
end

Flux.@functor PosEnc

function (pe::PosEnc)(input::Float32Matrix3DType)
    seq_len = size(input,2)
    return input .+ pe.pe_matrix[:,1:seq_len] # adds positional encoding to input embeddings
end

### building transformer section

struct Transf
    mha::Flux.MultiHeadAttention
    att_dropout::Flux.Dropout
    att_norm::Flux.LayerNorm # this is the normalization aspect
    mlp::Flux.Chain
    mlp_norm::Flux.LayerNorm
end

function Transf(
    embed_dim::Int, 
    hidden_dim::Int; 
    n_heads::Int, 
    dropout_prob::Float64
    )

    mha = Flux.MultiHeadAttention((embed_dim, embed_dim, embed_dim) => (embed_dim, embed_dim) => embed_dim, 
                                    nheads=n_heads, 
                                    dropout_prob=dropout_prob
                                    )

    att_dropout = Flux.Dropout(dropout_prob)
    
    att_norm = Flux.LayerNorm(embed_dim)
    
    mlp = Flux.Chain(
        Flux.Dense(embed_dim => hidden_dim, gelu),
        Flux.Dropout(dropout_prob),
        Flux.Dense(hidden_dim => embed_dim),
        Flux.Dropout(dropout_prob)
        )
    mlp_norm = Flux.LayerNorm(embed_dim)

    return Transf(mha, att_dropout, att_norm, mlp, mlp_norm)
end

Flux.@functor Transf

function (tf::Transf)(input::Float32Matrix3DType) # input shape: embed_dim × seq_len × batch_size
    normed = tf.att_norm(input)
    atted = tf.mha(normed, normed, normed)[1] # outputs a tuple (a, b)
    att_dropped = tf.att_dropout(atted)
    residualed = input + att_dropped
    res_normed = tf.mlp_norm(residualed)

    embed_dim, seq_len, batch_size = size(res_normed)
    reshaped = reshape(res_normed, embed_dim, seq_len * batch_size) # flatten to 2D for the MLP (Flux.Dense would also accept the 3D array directly)
    mlp_out = tf.mlp(reshaped)
    mlp_out_reshaped = reshape(mlp_out, embed_dim, seq_len, batch_size)
    
    tf_output = residualed + mlp_out_reshaped
    return tf_output
end

### full model as << ranked data --> token embedding --> position embedding --> transformer --> classifier head >>

struct Model
    embedding::Flux.Embedding
    pos_encoder::PosEnc
    pos_dropout::Flux.Dropout
    transformer::Flux.Chain
    classifier::Flux.Chain
end

function Model(;
    input_size::Int,
    embed_dim::Int,
    n_layers::Int,
    n_classes::Int,
    n_heads::Int,
    hidden_dim::Int,
    dropout_prob::Float64
    )

    embedding = Flux.Embedding(input_size => embed_dim)

    pos_encoder = PosEnc(embed_dim, input_size)

    pos_dropout = Flux.Dropout(dropout_prob)

    transformer = Flux.Chain(
        [Transf(embed_dim, hidden_dim; n_heads, dropout_prob) for _ in 1:n_layers]...
        )

    classifier = Flux.Chain(
        Flux.Dense(embed_dim => embed_dim, gelu),
        Flux.LayerNorm(embed_dim),
        Flux.Dense(embed_dim => n_classes)
        )

    return Model(embedding, pos_encoder, pos_dropout, transformer, classifier)
end

Flux.@functor Model

function (model::Model)(input::IntMatrix2DType)
    embedded = model.embedding(input)
    encoded = model.pos_encoder(embedded)
    encoded_dropped = model.pos_dropout(encoded)
    transformed = model.transformer(encoded_dropped)
    # pooled = dropdims(mean(transformed; dims=2), dims=2)
    logits_output = model.classifier(transformed)
    return logits_output
end
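
a CPU-only sketch of the positional encoding used above, to check the sin/cos layout and the broadcast over the batch dimension without CUDA or Flux (`sinusoidal_pe` is a hypothetical standalone copy of the loop in `PosEnc`, with small made-up sizes):

```julia
# same formula as in PosEnc: rows (2k-1, 2k) share one frequency, sin on odd rows, cos on even rows
function sinusoidal_pe(embed_dim::Int, max_len::Int)
    pe = Matrix{Float32}(undef, embed_dim, max_len)
    for pos in 1:max_len, i in 1:embed_dim
        angle = pos / 10000^(2 * div(i - 1, 2) / embed_dim)
        pe[i, pos] = isodd(i) ? sin(angle) : cos(angle)
    end
    return pe
end

pe = sinusoidal_pe(8, 5)               # embed_dim = 8, seq_len = 5
x = zeros(Float32, 8, 5, 2)            # embed_dim × seq_len × batch_size
out = x .+ pe                          # the 2D encoding broadcasts across the batch dimension
```

every sample in the batch gets the same encoding, and each sin/cos row pair satisfies sin² + cos² = 1 at every position.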

re-typed structure:

In [None]:
struct PosEnc{U<:AbstractMatrix}
    pe_matrix::U
end

function PosEnc(embed_dim::Int, max_len::Int) # max_len is usually maximum length of sequence but here it is just len(genes)
    pe_matrix = Matrix{Float32}(undef, embed_dim, max_len)
    for pos in 1:max_len, i in 1:embed_dim
        angle = pos / (10000^(2*(div(i-1,2))/embed_dim))
        if mod(i, 2) == 1
            pe_matrix[i,pos] = sin(angle) # odd indices
        else
            pe_matrix[i,pos] = cos(angle) # even indices
        end
    end
    return PosEnc(pe_matrix)
end

Flux.@functor PosEnc

function (pe::PosEnc)(input::Float32Matrix3DType)
    seq_len = size(input,2)
    return input .+ pe.pe_matrix[:,1:seq_len] # adds positional encoding to input embeddings
end

### building transformer section

struct Transf{MHA<:Flux.MultiHeadAttention, D<:Flux.Dropout, LN<:Flux.LayerNorm, C<:Flux.Chain}
    mha::MHA
    att_dropout::D
    att_norm::LN
    mlp::C
    mlp_norm::LN
end

function Transf(
    embed_dim::Int, 
    hidden_dim::Int; 
    n_heads::Int, 
    dropout_prob::Float64
    )

    mha = Flux.MultiHeadAttention((embed_dim, embed_dim, embed_dim) => (embed_dim, embed_dim) => embed_dim, 
                                    nheads=n_heads, 
                                    dropout_prob=dropout_prob
                                    )

    att_dropout = Flux.Dropout(dropout_prob)
    
    att_norm = Flux.LayerNorm(embed_dim)
    
    mlp = Flux.Chain(
        Flux.Dense(embed_dim => hidden_dim, gelu),
        Flux.Dropout(dropout_prob),
        Flux.Dense(hidden_dim => embed_dim),
        Flux.Dropout(dropout_prob)
        )
    mlp_norm = Flux.LayerNorm(embed_dim)

    return Transf(mha, att_dropout, att_norm, mlp, mlp_norm)
end

Flux.@functor Transf

function (tf::Transf)(input::Float32Matrix3DType) # input shape: embed_dim × seq_len × batch_size
    normed = tf.att_norm(input)
    atted, _ = tf.mha(normed, normed, normed) # outputs a tuple (a, b)
    att_dropped = tf.att_dropout(atted)
    residualed = input + att_dropped
    res_normed = tf.mlp_norm(residualed)

    embed_dim, seq_len, batch_size = size(res_normed)
    reshaped = reshape(res_normed, embed_dim, seq_len * batch_size) # flatten to 2D for the MLP (Flux.Dense would also accept the 3D array directly)
    mlp_out = tf.mlp(reshaped)
    mlp_out_reshaped = reshape(mlp_out, embed_dim, seq_len, batch_size)
    
    tf_output = residualed + mlp_out_reshaped
    return tf_output
end

struct Model{E<:Flux.Embedding, P<:PosEnc, D<:Flux.Dropout, T<:Flux.Chain, C<:Flux.Chain}
    embedding::E
    pos_encoder::P
    pos_dropout::D
    transformer::T
    classifier::C
end

function Model(;
    input_size::Int,
    embed_dim::Int,
    n_layers::Int,
    n_classes::Int,
    n_heads::Int,
    hidden_dim::Int,
    dropout_prob::Float64
    )

    embedding = Flux.Embedding(input_size => embed_dim)

    pos_encoder = PosEnc(embed_dim, input_size)

    pos_dropout = Flux.Dropout(dropout_prob)

    transformer = Flux.Chain(
    (Transf(embed_dim, hidden_dim; n_heads, dropout_prob) for _ in 1:n_layers)...
    )

    classifier = Flux.Chain(
        Flux.Dense(embed_dim => embed_dim, gelu),
        Flux.LayerNorm(embed_dim),
        Flux.Dense(embed_dim => n_classes)
        )

    return Model(embedding, pos_encoder, pos_dropout, transformer, classifier)
end

Flux.@functor Model

function (model::Model)(input::T) where {T<:IntMatrix2DType} 
    # issue here: the output type is inferred as Any, coming from the Flux portion
    # if Flux is what causes it, go into the source code and redefine the Embedding call with an explicit type parameter:
    # (m::Embedding)(x::T) where {T<:AbstractArray} = reshape(m(vec(x)), :, size(x)...), adapted from Flux source code
    # AND
    # input::T where {T<:type}, so the input is dispatched on as a subtype of the input type
    # should theoretically be able to avoid Anys, and be type-stable!
    embedded = model.embedding(input)
    encoded = model.pos_encoder(embedded)
    encoded_dropped = model.pos_dropout(encoded)
    transformed = model.transformer(encoded_dropped)
    pooled = dropdims(mean(transformed; dims=2), dims=2)
    logits_output = model.classifier(pooled)
    return logits_output
    # return embedded # debug only: unreachable after the return above
end
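
the point of the re-typed structs, as a minimal sketch without Flux: a field annotated with an abstract type (or a parametric type written without its parameters, such as `Flux.Chain`) is not concrete, so the compiler cannot know what it holds; a type parameter makes the field concrete per instance. `LooseLayer` and `TightLayer` are hypothetical toy structs:

```julia
# field annotated with an abstract type: the compiler only knows "some matrix"
struct LooseLayer
    weights::AbstractMatrix
end

# parametric version: the concrete matrix type is part of the struct's type
struct TightLayer{U<:AbstractMatrix}
    weights::U
end

loose = LooseLayer(rand(Float32, 2, 2))
tight = TightLayer(rand(Float32, 2, 2))

loose_concrete = isconcretetype(fieldtype(typeof(loose), :weights))   # false
tight_concrete = isconcretetype(fieldtype(typeof(tight), :weights))   # true
```

running `@code_warntype` on the model call is the usual way to confirm whether any `Any` remains after re-typing.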

TODO:
- need to 
    1. ~~re-run with fixed typing~~ faster.jl running on kraken gpu 1

    2. ~~figure out why test is better than train for loss/accuracy (potentially change in indef_run code or run mask_transf code for x epochs if test is only better in indef_run and not mask_transf)~~ ~~indef_run.jl code changed, running new on smaug gpu 2 ONCE OLD_INDEF_RUN.JL gets to 40!!! --> running new one now! what was the diff bruh~~ done + clarified fix, see indef_masked_rankings 08-04 vs. 08-05. issue was the withgradient (AGAIN!!!)

    3. ~~fix param logging for exp_transf + mask_transf (predstrues.csv, params.txt)~~

    4. ~~fix progressbar for indef_run.jl~~ removed progress bar lol

    5. ~~fix scatter plot for mask_transf comparison~~ ~~mask_transf_err.jl running on kraken gpu 0~~ done, see masked_rankings/2025-08-05

    6. ~~x-bin for exp_transf comparison~~ ~~exp_transf.jl running on smaug gpu 3~~ done, see masked_expression/2025-08-05

    7. reorganize exp, mask, indef, faster for tomorrow

    8. put fxns/structs into separate src files! more organized.

- 08-04 run is original test/loss calculations from creating 720ep graph
- 08-05 run is updated calculations from mask_transf code

aug 6, 2025 - recap
=

asap:
- ~~exp_transf.jl running on kraken 0 (10ep x-bin)~~ done
- ~~mask_transf_err.jl running on smaug 3 (10ep heatmap)~~ done

still pending:
- faster.jl running on kraken gpu 1 --> old code: 668774 ms, new code:
    - terminated - need to fix code
- reorganize exp, mask, indef, faster
- put fxns/structs into separate src files! more organized.
- ~~why not: exp transf run on untrt - kraken 0~~

aug 12, 2025 - predicting the average, ensuring no repeats, fixing plots
=

TODO:
- ~~redo heatmap/x-bin into boxplots or hex-bin?~~
    - exp on kraken 0, rank on smaug 0
- ~~see if model is just predicting the avg rather than actually learning (raw exp)~~
    - in exp code, running 100ep on kraken 0 for comparison
- see if possible to ensure model has no repeats (via permutations, inductive bias, pointer networks)
- put fxns/structs into separate src files! more organized.
- do test without pretrain to see if masking even helps
- faster.jl ; compare memory usage still high!!

aug 14, 2025 - for ensuring no repeats
=

a pointer network is designed to select its output from the elements that are **present in the input sequence**. however, the correct answer (the masked number) is the one element that is explicitly absent from the input.

https://kierszbaumsamuel.medium.com/pointer-networks-what-are-they-c3cb68fae076#:~:text=Notice%20how%20they%20are%20placed,-%20output%20dictionary%3A

https://arxiv.org/abs/2006.06380

alternatively:

can have two sets of inputs:
- context: masked sequence [1, 2, ..., MASK, 79, ...].
- candidate: complete, unmasked set of all possible tokens [1, 2, ..., 978].

where:
- encoder processes the context input to understand what's missing.
- decoder or attention mechanism then uses this context to point to the correct token within the candidate input.

thus, the model isn't pointing to the sequence it was given but to a complete "dictionary" of possibilities, using the masked sequence to figure out which item in the dictionary is the right one.
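
a rough sketch of the candidate idea in plain Julia: score every token in the full dictionary against a context vector, rule out the tokens that are already present in the input, and pick the best of what is left. this is an assumption-heavy illustration, not a trained model: `pick_missing`, the dot-product scoring, the `-Inf` exclusion, and the random embeddings are all hypothetical:

```julia
# context_vec: (dim,) summary of the masked sequence (would come from the encoder)
# candidate_emb: (dim, n_tokens) one embedding per token in the full dictionary
# present_ids: token ids that already appear in the (masked) input
function pick_missing(context_vec, candidate_emb, present_ids)
    scores = candidate_emb' * context_vec   # one score per candidate token
    scores[present_ids] .= -Inf             # tokens already in the input cannot be chosen again
    return argmax(scores)
end

n_tokens, dim = 5, 4
candidate_emb = rand(dim, n_tokens)         # random stand-in embeddings
context_vec = rand(dim)                     # random stand-in for the encoder output
present = [1, 2, 4, 5]                      # token 3 is the one that was masked

picked = pick_missing(context_vec, candidate_emb, present)   # only token 3 is still eligible
```

with a single mask per sample the exclusion alone pins down the answer; with several masks per sample it only narrows the candidates, and the learned scores have to do the rest.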