## Abstract

@Rijal2025 recently applied attention mechanisms to the problem of mapping genotype to phenotype. We noted that their architecture omitted standard Transformer components, prompting us to test whether including these elements could enhance performance on their multi-environment yeast dataset. Our analysis revealed that incorporating standard Transformer elements substantially improves predictive accuracy for this task.

----

:::{.callout-note title="AI usage disclosure" collapse="true"}
This is a placeholder for the AI usage disclosure. Once all authors sign the AI code form on AirTable, SlackBot will message you an AI disclosure that you should place here.
:::

# Introduction

The recent preprint by @Rijal2025 introduces an application of attention mechanisms for inferring genotype-phenotype maps, particularly focusing on capturing complex epistatic interactions. They showed that their attention-based model outperformed linear and linear + pairwise models. Overall, their work sparked considerable interest and discussion within our journal club.

While appreciating the novelty of applying attention in this domain, we noted that the specific architecture employed is relatively minimal. It utilizes stacked attention layers but omits several components commonly found in the original transformer model [@Vaswani2017], such as skip connections, layer normalization, and feed-forward blocks. We found these omissions interesting, since the transformer was the first architecture to fully leverage the power of attention mechanisms, and did so to great success in many distinct domains.

This led us to a straightforward question: could the performance of the attention-based genotype-phenotype model proposed by @Rijal2025 be improved by replacing it with a standard transformer architecture?

## The dataset

The experimental data used in @Rijal2025 comes from the work of @Ba2022, who performed a large-scale quantitative trait locus (QTL) study in yeast. In short, they measured the growth rates of ~100,000 yeast segregants across 18 conditions and for ~40,000 loci, creating a massive dataset suitable for mapping genotype to phenotype.

Due to extensive linkage disequilibrium (LD), the loci in the dataset are highly correlated with each other. To create a set of independent loci, @Rijal2025 defined a set of loci such that the correlation between the SNPs present at any pair of loci is less than 94%, resulting in a set of 1164 "independent" loci.
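To make the filtering idea concrete, here is a minimal sketch of one way to pick loci whose pairwise correlation stays below a threshold. It is an illustration under assumptions, not the procedure from the @Rijal2025 notebook: the greedy left-to-right scan, the use of absolute Pearson correlation, and the names `pearson` and `prune_correlated_loci` are all our own choices for this example.

```python
from math import sqrt


def pearson(a, b):
    """Pearson correlation between two equal-length genotype vectors."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    if var_a == 0 or var_b == 0:
        return 0.0  # a constant locus carries no correlation signal
    return cov / sqrt(var_a * var_b)


def prune_correlated_loci(genotypes, threshold=0.94):
    """Greedily keep loci whose |r| with every already-kept locus is below threshold.

    genotypes: list of per-locus vectors, one entry per segregant.
    Returns the indices of the kept loci (assumed, hypothetical procedure).
    """
    kept = []
    for idx, locus in enumerate(genotypes):
        if all(abs(pearson(locus, genotypes[k])) < threshold for k in kept):
            kept.append(idx)
    return kept


# Toy data: 4 loci genotyped in 8 segregants (0/1 = parental allele).
toy_genotypes = [
    [0, 0, 1, 1, 0, 1, 0, 1],
    [0, 0, 1, 1, 0, 1, 0, 1],  # identical to locus 0, so it is dropped
    [0, 0, 1, 1, 0, 1, 1, 1],  # differs from locus 0 in one segregant
    [1, 0, 0, 1, 1, 0, 0, 1],
]
kept_loci = prune_correlated_loci(toy_genotypes)
```

On real data with ~40,000 loci, a vectorized implementation would be needed; this version only shows the logic.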

Unfortunately, they didn't provide this set of loci, nor the genotypic and phenotypic data used for training, so we located the [raw data](https://datadryad.org/dataset/doi:10.5061/dryad.1rn8pk0vd) that @Ba2022 originally uploaded alongside their study, then used [this notebook](https://github.com/Emergent-Behaviors-in-Biology/GenoPhenoMapAttention/blob/main/obtain_independent_loci.ipynb) uploaded by @Rijal2025 to recapitulate the 1164 loci. To save everyone else the trouble, we uploaded the train, test, and validation datasets we're *pretty sure* @Rijal2025 used in their study.

You can find the data here:

```
s3://2025-attention-is-almost-all-you-need/datasets
```

This analysis will use this data, so let's go ahead and download it into a directory:

In [1]:
import subprocess
from pathlib import Path

dataset_dir = Path("datasets/")
remote_dir = "s3://2025-attention-is-almost-all-you-need/datasets/"

# Pass arguments as a list (robust to spaces in paths) and fail loudly on error.
subprocess.run(["aws", "s3", "sync", remote_dir, str(dataset_dir)], check=True)

CompletedProcess(args=['aws', 's3', 'sync', 's3://2025-attention-is-almost-all-you-need/datasets/', 'datasets'], returncode=0)

## Reproducing the results

To be sure we correctly reverse-engineered the specifics of their training/validation/test datasets, we used their Jupyter notebooks to reproduce their single-environment attention models shown in Figure 3:

![Figure 3 from @Rijal2025. Original caption: "*Comparison of model performance in yeast QTL mapping data. We show R2 on test datasets for linear, linear+pairwise, and attention-based model (with d = 12) across 18 phenotypes (relative growth rates in various environments). For the linear + pairwise mode, the causal loci inferred by @Ba2022 are used.*"](assets/fig3.jpg){fig-align="center" width=70% fig-alt="Figure 3 from @Rijal2025 showing the single-environment model performances."}

Each model trained in roughly TODO:Erin hours, and for simplicity we offloaded this analysis to a Jupyter notebook you can find [here](TODO:Erin). 

TODO:Erin

From this we concluded that we can reproduce their results and that our dataset partitioning is the same as, or functionally equivalent to, theirs.

## Code infrastructure

Rather than continuing to work with their notebook files, we re-implemented their code into our own codebase to improve code quality and make room for our planned modifications. Here is a summary of the changes we made:

* Added a `RunParams` dataclass that holds all available options for model and training specification.
* Saved the training/validation/test datasets to file and created PyTorch `DataLoader` objects to manage data access, batching, and shuffling.
* Automated the training loop with [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), promoting separation between the training loop and the core scientific logic of their model.
* Added early stopping to end training if learning stagnates.
* Added a parameter to turn skip connections on/off.
* Generalized their model to have an arbitrary number of layers instead of fixing the number of layers to three.
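As a rough illustration of the first and fourth bullets, the sketch below shows what a parameter dataclass and a patience-based early-stopping check could look like. It is not the code from our repository: `RunParamsSketch`, `should_stop_early`, and the `skip_connections` and `patience` fields are hypothetical names chosen for this example, and the remaining defaults simply mirror the values used later in this pub.

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class RunParamsSketch:
    """Hypothetical container for model and training options."""

    data_dir: Path
    save_dir: Path
    name_prefix: str
    embedding_dim: int = 13
    num_layers: int = 3
    batch_size: int = 64
    learning_rate: float = 0.001
    max_epochs: int = 200
    skip_connections: bool = False  # hypothetical toggle for skip connections
    patience: int = 10  # hypothetical: epochs to wait without improvement


def should_stop_early(val_losses, patience):
    """Return True once the best validation loss is at least `patience` epochs old."""
    if not val_losses:
        return False
    best_epoch = val_losses.index(min(val_losses))
    epochs_since_best = len(val_losses) - 1 - best_epoch
    return epochs_since_best >= patience


params = RunParamsSketch(
    data_dir=Path("datasets"),
    save_dir=Path("models"),
    name_prefix="demo",
    patience=2,
)
stagnated = should_stop_early([1.0, 0.8, 0.9, 0.85], params.patience)
still_improving = should_stop_early([1.0, 0.8, 0.7], params.patience)
```

In practice PyTorch Lightning provides its own early-stopping callback, so a hand-written check like this is only meant to show the logic.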

## Reproducing their results (again)

To make sure our re-implementation still reproduces Figure 3, we created a function, `train_single_phenotype`, that trains their model for a given phenotype, selects the best model (the one with the highest *validation* dataset $R^2$), reports that model's *test* dataset $R^2$, and saves the model to file for downstream use.
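The selection rule can be sketched in a few lines. This is a simplified stand-in rather than the body of `train_single_phenotype`: the `r_squared` and `select_best_checkpoint` helpers and the dictionary layout of the checkpoints are assumptions made for illustration.

```python
def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_y = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_y) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


def select_best_checkpoint(checkpoints, y_val, y_test):
    """Pick the checkpoint with the highest validation R^2, then score it on test.

    checkpoints: list of dicts with keys 'name', 'val_pred', 'test_pred'
    (a hypothetical layout used only in this sketch).
    """
    best = max(checkpoints, key=lambda c: r_squared(y_val, c["val_pred"]))
    return best["name"], r_squared(y_test, best["test_pred"])


# Toy example: checkpoint "b" happens to fit the test set perfectly, but it is
# not chosen because only validation performance is used for selection.
y_val, y_test = [1.0, 2.0, 3.0], [2.0, 4.0, 7.0]
checkpoints = [
    {"name": "a", "val_pred": [1.0, 2.0, 3.0], "test_pred": [2.0, 4.0, 6.0]},
    {"name": "b", "val_pred": [1.0, 1.0, 1.0], "test_pred": [2.0, 4.0, 7.0]},
]
best_name, test_r2 = select_best_checkpoint(checkpoints, y_val, y_test)
```

Keeping the test set out of model selection is what makes the reported test $R^2$ an unbiased estimate of performance on unseen data.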

`train_single_phenotype` accepts a dataclass (`RunParams`) as input that contains all required parameters, so let's create this object to match the run specifications of @Rijal2025.

In [None]:
from analysis.rijal_et_al import RijalEtAlConfig

config = RijalEtAlConfig(
    data_dir=dataset_dir,
    save_dir=Path("models"),
    name_prefix="reproduce",
    embedding_dim=13,
    num_layers=3,
    batch_size=64,
    learning_rate=0.001,
    max_epochs=200,
)


In [5]:
train_single_phenotype(config, "23C")

Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | model   | StackedAttention | 30.6 K | train
1 | loss_fn | MSELoss          | 0      | train
-----------------------------------------------------
30.6 K    Trainable params
0         Non-trainable params
30.6 K    Total params
0.122     Total estimated model params size (MB)
5         Modules in train mode
0         Modules in eval mode


Epoch 0:   8%|▊         | 87/1129 [00:38<07:38,  2.27it/s, v_num=0, train_loss=0.937]



## Adding skip connections improves performance

@Rijal2025 include an important note in their Discussion:

> There are numerous interesting potential extensions and future research directions using attention-based models. These include exploring alternative architectures for how genetic and environmental tokens interact and the use of non-linear MLP layers and skip connections.

Since skip connections form an important component of the transformer, this seemed like a good first step for us. Also, unlike other components of the transformer, skip connections don't add any parameters to the model, so we can do an apples-to-apples comparison with and without skip connections without worrying about differences in model complexity.
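The parameter-count argument is easy to check on a toy example. The framework-free sketch below is not the `StackedAttention` model; `ToyLinear` and `Residual` are hypothetical classes that only illustrate how wrapping a layer as $x + f(x)$ changes the output while leaving the number of trainable parameters untouched.

```python
import random


class ToyLinear:
    """A toy dense layer with a square weight matrix and a bias vector."""

    def __init__(self, dim, seed=0):
        rng = random.Random(seed)
        self.weight = [[rng.gauss(0, 0.1) for _ in range(dim)] for _ in range(dim)]
        self.bias = [0.0] * dim

    def num_params(self):
        return sum(len(row) for row in self.weight) + len(self.bias)

    def __call__(self, x):
        return [
            sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(self.weight, self.bias)
        ]


class Residual:
    """Wraps a layer so that output = x + layer(x); it owns no parameters itself."""

    def __init__(self, layer):
        self.layer = layer

    def num_params(self):
        return self.layer.num_params()

    def __call__(self, x):
        return [xi + yi for xi, yi in zip(x, self.layer(x))]


plain = ToyLinear(dim=4)
with_skip = Residual(ToyLinear(dim=4))  # same seed, so identical weights

x = [1.0, 2.0, 3.0, 4.0]
plain_out = plain(x)
skip_out = with_skip(x)
```

Both variants report the same `num_params()`, and `skip_out` equals `plain_out` with the input added element-wise.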

In [None]:
from analysis.dataset import create_dataloaders, phenotype_names

for phenotype_name in phenotype_names:
    dataloaders = create_dataloaders(dataset_dir)
