## Abstract

@Rijal2025 recently applied attention mechanisms to the problem of mapping genotype to phenotype. We noted that their architecture omits standard Transformer components, which prompted us to test whether including these elements could improve performance on their multi-environment yeast dataset. Our analysis revealed that incorporating standard Transformer elements substantially improves predictive accuracy for this task.

----

:::{.callout-note title="AI usage disclosure" collapse="true"}
This is a placeholder for the AI usage disclosure. Once all authors sign the AI code form on AirTable, SlackBot will message you an AI disclosure that you should place here.
:::

# Introduction

The recent preprint by @Rijal2025 introduces an application of attention mechanisms for inferring genotype-phenotype maps, particularly focusing on capturing complex epistatic interactions. They showed that their attention-based model outperformed linear and linear + pairwise models. Overall, their work sparked considerable interest and discussion within our journal club.

While appreciating the novelty of applying attention in this domain, we noted that the specific architecture employed is relatively minimal. It uses stacked attention layers but omits several components of the original transformer model [@Vaswani2017], such as skip connections, layer normalization, and feed-forward blocks. We found these omissions interesting, since the transformer was the first architecture to fully leverage the power of attention mechanisms, and did so with great success in many distinct domains.
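To make the difference concrete, here is a minimal NumPy sketch that contrasts an attention-only layer with a post-LayerNorm transformer block in the style of @Vaswani2017. It is a simplification, not a reproduction of either model:

* It uses a single attention head.
* It omits the learned LayerNorm gain and bias.
* `attention_only_layer` is only our assumption about the general shape of a minimal design, not the exact layer used by @Rijal2025.
* The sizes (8 loci, embedding dimension 12, 3 layers) are arbitrary toy values.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize each locus embedding (row) to zero mean and unit variance.
    # Learned gain and bias parameters are omitted for brevity.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def self_attention(x, p):
    # Single-head scaled dot-product attention; every locus attends to every locus.
    q, k, v = x @ p["w_q"], x @ p["w_k"], x @ p["w_v"]
    scores = q @ k.T / np.sqrt(k.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def feed_forward(x, p):
    # Position-wise MLP (ReLU), applied to each locus independently.
    return np.maximum(0.0, x @ p["w_1"]) @ p["w_2"]


def attention_only_layer(x, p):
    # Minimal design (assumed): the attention output replaces the input outright.
    return self_attention(x, p)


def transformer_layer(x, p):
    # Post-LN block: residual + LayerNorm around attention,
    # then residual + LayerNorm around the feed-forward block.
    x = layer_norm(x + self_attention(x, p))
    return layer_norm(x + feed_forward(x, p))


def init_params(rng, d, d_ff):
    return {
        "w_q": rng.normal(0.0, d**-0.5, (d, d)),
        "w_k": rng.normal(0.0, d**-0.5, (d, d)),
        "w_v": rng.normal(0.0, d**-0.5, (d, d)),
        "w_1": rng.normal(0.0, d**-0.5, (d, d_ff)),
        "w_2": rng.normal(0.0, d_ff**-0.5, (d_ff, d)),
    }


rng = np.random.default_rng(0)
seq_length, d, num_layers = 8, 12, 3  # toy sizes, not the real 1164 loci
x = rng.normal(size=(seq_length, d))
layers = [init_params(rng, d, 4 * d) for _ in range(num_layers)]

h_minimal, h_transformer = x, x
for p in layers:
    h_minimal = attention_only_layer(h_minimal, p)
    h_transformer = transformer_layer(h_transformer, p)

print(h_minimal.shape, h_transformer.shape)  # (8, 12) (8, 12)
```

The residual path lets each layer learn a correction to its input rather than a full replacement. LayerNorm keeps activations on a stable scale as layers are stacked. These are the usual motivations for including both.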

This led us to a straightforward question: could the performance of the attention-based genotype-phenotype model proposed by @Rijal2025 be improved by using a standard transformer architecture instead?

## The dataset

The experimental data used in @Rijal2025 comes from the work of @Ba2022, who performed a large-scale quantitative trait locus (QTL) study in yeast. In short, they genotyped ~100,000 yeast segregants at ~40,000 loci and measured their growth rates in 18 conditions, creating a massive dataset suitable for mapping genotype to phenotype.

Due to extensive linkage disequilibrium (LD), the loci in the dataset are highly correlated with each other. To create a set of independent loci, @Rijal2025 defined a set of loci such that the correlation between the SNPs at any pair of loci is less than 94%, resulting in a set of 1164 "independent" loci.
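One simple way to impose such a threshold is greedy pruning. Walk through the loci in order, and keep a locus only if its absolute correlation with every already-kept locus is below the cutoff. The sketch below is a hypothetical illustration of that idea, not the procedure in @Rijal2025's notebook. The `prune_loci` helper, the 0/1 genotype coding, and the use of absolute Pearson correlation are all our assumptions.

```python
import math


def pearson(x, y):
    # Pearson correlation between two genotype vectors (one entry per segregant).
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    if sxx == 0 or syy == 0:
        return 0.0  # our choice: treat a monomorphic locus as uncorrelated
    return sxy / math.sqrt(sxx * syy)


def prune_loci(genotypes, threshold=0.94):
    # Greedily keep a locus only if |r| < threshold against every locus kept so far.
    kept = []
    for i, g in enumerate(genotypes):
        if all(abs(pearson(g, genotypes[j])) < threshold for j in kept):
            kept.append(i)
    return kept


# Toy data: 4 loci x 6 segregants; 0/1 marks which parental allele is present.
genotypes = [
    [0, 1, 0, 1, 1, 0],
    [0, 1, 0, 1, 1, 0],  # identical to locus 0 (complete LD), so it is dropped
    [1, 1, 0, 0, 1, 0],
    [0, 1, 0, 1, 1, 1],
]
print(prune_loci(genotypes))  # [0, 2, 3]
```

At the real scale (~40,000 loci and ~100,000 segregants), this pure-Python loop would be far too slow. A practical version would likely vectorize the correlations or only compare nearby loci.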

Unfortunately, they didn't provide this set of loci, nor the genotypic and phenotypic data used for training, so we located the [raw data](https://datadryad.org/dataset/doi:10.5061/dryad.1rn8pk0vd) that @Ba2022 originally uploaded alongside their study, then used [this notebook](https://github.com/Emergent-Behaviors-in-Biology/GenoPhenoMapAttention/blob/main/obtain_independent_loci.ipynb) uploaded by @Rijal2025 to recapitulate the 1164 loci. To save everyone else the trouble, we uploaded the train, test, and validation datasets we're *pretty sure* @Rijal2025 used in their study.

You can find the data here:

```
s3://2025-attention-is-almost-all-you-need/datasets
```

We use this data in what follows, so let's go ahead and download it into a directory:

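```python
import subprocess
from pathlib import Path

dataset_dir = Path("datasets/")
remote_dir = "s3://2025-attention-is-almost-all-you-need/datasets/"

# Pass the arguments as a list (rather than splitting a string) so paths with
# spaces are handled safely; check=True raises an error if the sync fails.
subprocess.run(["aws", "s3", "sync", remote_dir, str(dataset_dir)], check=True)
```

```
CompletedProcess(args=['aws', 's3', 'sync', 's3://2025-attention-is-almost-all-you-need/datasets/', 'datasets'], returncode=0)
```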

## Code infrastructure

Rather than continuing to work with their notebook files, we re-implemented their code in our own codebase to improve code quality and make room for our planned modifications. Here is a summary of the changes we made:

* Added a `RunParams` dataclass that holds all available options for model and training specification.
* Saved the training/validation/test datasets to file and created PyTorch `DataLoader` objects to manage data access, batching, and shuffling.
* Automated the training loop with [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), separating the training loop from the core scientific logic of their model.
* Added early stopping to end training when learning stagnates.
* Added a parameter to turn skip connections on/off.
* Generalized their model to have an arbitrary number of layers instead of fixing the number of layers to three.
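Early stopping is usually implemented with a "patience" counter. PyTorch Lightning ships a built-in `EarlyStopping` callback (`lightning.pytorch.callbacks.EarlyStopping`, configured with `monitor`, `patience`, and `mode`). The plain-Python sketch below is only a hypothetical illustration of the underlying logic, not our codebase's implementation. It assumes the monitored metric is validation $R^2$, where higher is better. The `val_r2_history` values are made up.

```python
class EarlyStopping:
    """Signal a stop when the validation metric hasn't improved for `patience` epochs.

    Assumes higher is better (e.g. validation R^2).
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.epochs_without_improvement = 0

    def step(self, metric):
        """Record one epoch's metric; return True if training should stop."""
        if metric > self.best + self.min_delta:
            self.best = metric
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Made-up validation R^2 values, one per epoch.
val_r2_history = [0.10, 0.20, 0.25, 0.24, 0.25, 0.23, 0.30]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_r2 in enumerate(val_r2_history):
    if stopper.step(val_r2):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)  # 5 0.25
```

Epochs 3, 4, and 5 fail to beat the best value of 0.25, so training stops after epoch 5. The later value of 0.30 is never seen, which is the usual trade-off when choosing the patience.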

## Reproducing the single environment model results

To make sure we correctly reverse-engineered the specifics of their training/validation/test datasets and accurately re-implemented their model, let's try to reproduce the attention model's performance in Figure 3 (red dots):

![Figure 3 from @Rijal2025. Original caption: "*Comparison of model performance in yeast QTL mapping data. We show R2 on test datasets for linear, linear+pairwise, and attention-based model (with d = 12) across 18 phenotypes (relative growth rates in various environments). For the linear + pairwise mode, the causal loci inferred by @Ba2022 are used.*"](assets/fig3.jpg){fig-align="center" width=70% fig-alt="Figure 3 from @Rijal2025 showing the single-environment model performances."}

To do this, we created a high-level entry point for training that:

* Trains a model for a given phenotype
* Determines the *best model*, defined as the model with the highest $R^2$ calculated over the *validation* dataset
* Reports the $R^2$ for the *test* dataset using the best model
* Saves the model to file for downstream use
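The selection logic in the list above can be sketched in plain Python. This is a hypothetical illustration, not our codebase's implementation. It assumes $R^2$ is the coefficient of determination, $1 - SS_\text{res}/SS_\text{tot}$; some studies instead report the squared Pearson correlation. It also stores each checkpoint's predictions in a made-up dictionary. The key point is that the test set plays no role in choosing the model.

```python
def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def select_best_and_report(checkpoints, y_val, y_test):
    """Pick the checkpoint with the highest validation R^2, then score it on test.

    `checkpoints` maps a checkpoint name to its predictions on the validation
    and test sets (a hypothetical structure used only for this sketch).
    """
    best_name = max(checkpoints, key=lambda name: r_squared(y_val, checkpoints[name]["val"]))
    test_r2 = r_squared(y_test, checkpoints[best_name]["test"])
    return best_name, test_r2


y_val = [1.0, 2.0, 3.0, 4.0]
y_test = [2.0, 4.0]
checkpoints = {
    "epoch=1": {"val": [1.5, 2.5, 2.5, 3.5], "test": [2.5, 3.5]},  # val R^2 = 0.8
    "epoch=2": {"val": [1.0, 2.0, 3.0, 4.0], "test": [2.0, 3.0]},  # val R^2 = 1.0
}
best, test_r2 = select_best_and_report(checkpoints, y_val, y_test)
print(best, test_r2)  # epoch=2 0.5
```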

This entry point requires information about the model architecture and about how training should proceed. Both are specified via two configuration objects, `ModelConfig` and `TrainConfig`. Here are the `ModelConfig` and `TrainConfig` that match the design of the experiment that produced Figure 3:

```python
from analysis.base import ModelConfig, TrainConfig

model_config = ModelConfig(
    model_type="rijal_et_al",
    seq_length=1164,
    embedding_dim=13,
    num_layers=3,
)

train_config = TrainConfig(
    # data_dir=dataset_dir,
    data_dir=Path("data_subsubset"),
    save_dir=Path("models"),
    name_prefix="reproduce_fig3",
    phenotype="23C",
    optimizer="adam",
    batch_size=64,
    learning_rate=0.001,
    lr_schedule=False,
    weight_decay=0.0,
    max_epochs=5,
    gradient_clip_val=0,
    use_modal=True,
)
```

This `train_config` is specific to the first phenotype, `23C`. Before training on all 18 phenotypes, let's train a model for this one phenotype:

```python
from analysis.train import run_training

run_training(model_config, train_config)
```

```
Starting training job... Check tensorboard server for progress.
All required files exist remotely.
Finished training.
Training completed. Run artifacts saved in Modal volume at: /data/models/reproduce_fig3/lightning_logs/version_4
Downloading run directory to models/reproduce_fig3/lightning_logs/version_4...

PosixPath('models/reproduce_fig3/lightning_logs/version_4')
```
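Training every phenotype only requires changing the `phenotype` field, so per-phenotype configurations can be derived from a single base config. Our code does this with `attrs.evolve`. The stdlib sketch below shows the same pattern with `dataclasses.replace`. `TrainConfigSketch` and every phenotype name other than `23C` are hypothetical.

```python
from dataclasses import dataclass, replace
from pathlib import Path


@dataclass(frozen=True)
class TrainConfigSketch:
    # Hypothetical stand-in mirroring a few TrainConfig fields.
    data_dir: Path
    name_prefix: str
    phenotype: str
    max_epochs: int = 5


base = TrainConfigSketch(
    data_dir=Path("datasets"), name_prefix="reproduce_fig3", phenotype="23C"
)

# Hypothetical names; in our code the list comes from analysis.dataset.phenotype_names.
phenotypes = ["23C", "env_2", "env_3"]

# replace() builds a new frozen instance with only `phenotype` changed,
# leaving `base` untouched (analogous to attrs.evolve).
configs = [replace(base, phenotype=p) for p in phenotypes]

for cfg in configs:
    print(cfg.phenotype, cfg.name_prefix, cfg.max_epochs)
```

Every derived config here shares the same `name_prefix`. If output directories are keyed on that prefix, also varying it per phenotype (e.g. `replace(base, phenotype=p, name_prefix=f"reproduce_fig3_{p}")`) would keep runs apart.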

```python
import attrs
from analysis.dataset import phenotype_names
from analysis.train import run_trainings

jobs = []
for phenotype in phenotype_names:
    phenotype_config = attrs.evolve(train_config, phenotype=phenotype)
    jobs.append((model_config, phenotype_config))

run_trainings(jobs)
```

```
Starting training job... Check tensorboard server for progress.
All required files exist remotely.
Training completed. Run artifacts saved in Modal volume at: /data/models/reproduce_fig3/lightning_logs/version_5
Downloading run directory to models/reproduce_fig3/lightning_logs/version_5...
```

[PosixPath('models/reproduce_fig3/lightning_logs/version_5'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_5'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_5'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_5'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_5'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_5'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_5'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_5'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_5'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_5'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_5'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_5'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_5'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_6'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_6'),
 PosixPath('models/reproduce_fig3/lightning_logs/version_6'),
 PosixPa