## Parallel Head Data Preprocessing Example

### Background and motivation
In this notebook, we demonstrate how to preprocess data for training BioNeMo Evo2 models with multiple prediction heads in parallel. Alongside a DNA head, additional heads can learn from other biological signals, such as RNA-seq and ChIP-seq coverage, so a single model can be trained on several data types at once.

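For this example, we focus on RNA-seq coverage stored in BigWig files. We preprocess it together with the matching genome FASTA, fine-tune a pretrained Evo2 1B checkpoint with a DNA head and an RNA-seq head, and then run predictions on an example gene.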

In [None]:
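# Replace the installed config.py with the modified version that supports parallel heads.
# The backup is created only if it does not exist yet, so re-running this cell keeps the original config.py.
![ -f /usr/local/lib/python3.12/dist-packages/bionemo/evo2/utils/config.py.bak ] || cp /usr/local/lib/python3.12/dist-packages/bionemo/evo2/utils/config.py /usr/local/lib/python3.12/dist-packages/bionemo/evo2/utils/config.py.bak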
!cp /workspace/sub-packages/bionemo-evo2/src/bionemo/evo2/utils/config.py /usr/local/lib/python3.12/dist-packages/bionemo/evo2/utils/config.py

In [None]:
import os

from bionemo.core.utils.subprocess_utils import run_subprocess_safely  # noqa


data_path = "parallel_head_data"

In [None]:
CLEANUP: bool = True
if CLEANUP and os.path.exists(data_path):
    !rm -rf {data_path}
    !rm -rf ./preprocessed_data
    !rm -f parallel_preprocess_config.yaml

In [None]:
if not os.path.exists(data_path):
    !mkdir -p {data_path}
    !wget https://storage.googleapis.com/tbb-public-bucket/datasets/parallel-head-example/GCA_000525045.1_DREv1_genomic.fna -O {data_path}/GCA_000525045.1_DREv1_genomic.fna
    !wget https://storage.googleapis.com/tbb-public-bucket/datasets/parallel-head-example/SRR1145649_forward.normalized.bw -O {data_path}/SRR1145649_forward.normalized.bw
    !wget https://storage.googleapis.com/tbb-public-bucket/datasets/parallel-head-example/SRR1145649_reverse.normalized.bw -O {data_path}/SRR1145649_reverse.normalized.bw
    !wget https://storage.googleapis.com/tbb-public-bucket/datasets/parallel-head-example/GCA_000525045.1_DREv1_genomic.gtf -O {data_path}/GCA_000525045.1_DREv1_genomic.gtf
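
`wget -O` creates the output file before the transfer starts, so a failed download can leave an empty file behind. Because the cell above only runs when `parallel_head_data` is missing, it will not retry on its own. The minimal sketch below uses a hypothetical `find_bad_downloads` helper (not part of BioNeMo) to flag expected files that are missing or zero bytes.

```python
import os


def find_bad_downloads(directory: str, filenames: list[str]) -> list[str]:
    """Return the names of expected files that are missing or empty (hypothetical helper)."""
    bad = []
    for name in filenames:
        path = os.path.join(directory, name)
        # A missing file or a zero-byte file both indicate a failed download
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            bad.append(name)
    return bad


expected = [
    "GCA_000525045.1_DREv1_genomic.fna",
    "SRR1145649_forward.normalized.bw",
    "SRR1145649_reverse.normalized.bw",
    "GCA_000525045.1_DREv1_genomic.gtf",
]
bad = find_bad_downloads("parallel_head_data", expected)
print("All files downloaded" if not bad else f"Re-download these files: {bad}")
```

If anything is reported, delete `parallel_head_data` and re-run the download cell.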

In [None]:
# Let's create a YAML config for preprocessing with RNA-Seq bigwig files.
fasta_base = "GCA_000525045.1_DREv1_genomic.fna"
bigwig_forward = "SRR1145649_forward.normalized.bw"  # No need for reverse, since both are handled together.
full_fasta_path = os.path.abspath(os.path.join(data_path, fasta_base))
output_prefix = "fungi_dna_rnaseq"

output_dir = os.path.abspath("preprocessed_data")
output_yaml = f"""
- datapaths: ["{full_fasta_path}"]
  output_dir: "{output_dir}"
  output_prefix: {output_prefix}
  train_split: 0.9
  valid_split: 0.05
  test_split: 0.05
  overwrite: True
  embed_reverse_complement: true
  random_reverse_complement: 0.0
  random_lineage_dropout: 0.0
  include_sequence_id: false
  transcribe: "back_transcribe"
  force_uppercase: false
  indexed_dataset_dtype: "uint8"
  tokenizer_type: "Byte-Level"
  vocab_file: null
  vocab_size: null
  merges_file: null
  pretrained_tokenizer_model: null
  special_tokens: null
  fast_hf_tokenizer: true
  append_eod: true
  enforce_sample_length: null
  ftfy: false
  workers: 1
  preproc_concurrency: 100000
  chunksize: 25
  drop_empty_sequences: true
  nnn_filter: false  # If you split your fasta on NNN (in human these are contigs), then you should set this to true.
  seed: 12342  # Not relevant because we are not using random reverse complement or lineage dropout.
  fasta_rnaseq_bigwig_map:
    {fasta_base}: {os.path.abspath(os.path.join(data_path, bigwig_forward))}
"""
with open("parallel_preprocess_config.yaml", "w") as f:
    print(output_yaml, file=f)
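
In the config above, `fasta_rnaseq_bigwig_map` is keyed by the FASTA file's base name (`fasta_base`), while `datapaths` holds the absolute path, so both must refer to the same file name. The train/valid/test fractions should also sum to 1. The sketch below builds the map and checks both conditions. `build_bigwig_map` and `check_config` are hypothetical helpers, and the sketch assumes the preprocessor matches map keys against the base names of `datapaths`, as the config suggests.

```python
import math
import os


def build_bigwig_map(fasta_to_bigwig: dict[str, str]) -> dict[str, str]:
    """Map each FASTA base name to the absolute path of its RNA-seq BigWig (hypothetical helper)."""
    return {os.path.basename(fasta): os.path.abspath(bw) for fasta, bw in fasta_to_bigwig.items()}


def check_config(datapaths: list[str], bigwig_map: dict[str, str], splits: tuple[float, float, float]) -> list[str]:
    """Return the problems found in one preprocessing config entry (empty list if none)."""
    problems = []
    if not math.isclose(sum(splits), 1.0):
        problems.append(f"train/valid/test splits sum to {sum(splits)}, not 1.0")
    for path in datapaths:
        # Assumption: the BigWig is looked up by the FASTA base name
        if os.path.basename(path) not in bigwig_map:
            problems.append(f"no BigWig mapped for {os.path.basename(path)}")
    return problems


fasta_path = os.path.abspath("parallel_head_data/GCA_000525045.1_DREv1_genomic.fna")
bw_map = build_bigwig_map({fasta_path: "parallel_head_data/SRR1145649_forward.normalized.bw"})
print(check_config([fasta_path], bw_map, (0.9, 0.05, 0.05)))
```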

In [None]:
# Now we can run the preprocessing script with this config.
!python \
    /workspace/sub-packages/bionemo-evo2/src/bionemo/evo2/utils/heads/preprocess.py \
    --config parallel_preprocess_config.yaml
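
The training data config created later refers to the preprocessed splits by prefix: `{output_dir}/{output_prefix}_byte-level_{train,val,test}`. The sketch below checks that preprocessing produced them. It assumes each split is written as a Megatron-style indexed dataset (a `.bin` data file plus an `.idx` index file); that is an assumption about the preprocessor's output, not something this notebook shows. `split_prefixes` and `missing_split_files` are hypothetical helpers.

```python
import os


def split_prefixes(output_dir: str, output_prefix: str, tokenizer_tag: str = "byte-level") -> dict[str, str]:
    """Return the dataset prefix for each split, following the naming used in training_data_config.yaml."""
    base = os.path.join(os.path.abspath(output_dir), output_prefix)
    return {split: f"{base}_{tokenizer_tag}_{split}" for split in ("train", "val", "test")}


def missing_split_files(prefixes: dict[str, str], suffixes: tuple[str, ...] = (".bin", ".idx")) -> list[str]:
    """List expected files that do not exist (assumes .bin/.idx pairs per split)."""
    return [p + s for p in prefixes.values() for s in suffixes if not os.path.exists(p + s)]


prefixes = split_prefixes("preprocessed_data", "fungi_dna_rnaseq")
missing = missing_split_files(prefixes)
print("All splits present" if not missing else f"Missing: {missing}")
```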

Now that we have a prepared dataset, we can proceed to train our model using the parallel head approach. This involves defining a model architecture that can handle multiple outputs and configuring the training process to optimize for each head simultaneously.

We will use the simple dataset we created in the previous section to illustrate this process.
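
Conceptually, optimizing every head at once means combining one loss per head into a single training objective. The pure-Python toy below illustrates the idea for a DNA head (cross-entropy over next-nucleotide predictions) and an RNA-seq head (mean squared error against coverage). The choice of losses, the `rna_weight` factor, and the `combined_loss` function are assumptions for illustration only; they are not the loss used by `train_parallel.py`.

```python
import math


def cross_entropy(logits: list[list[float]], targets: list[int]) -> float:
    """Mean cross-entropy of per-position logits against integer token targets."""
    total = 0.0
    for row, target in zip(logits, targets):
        # log-sum-exp with max subtraction for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]
    return total / len(targets)


def mse(predicted: list[float], observed: list[float]) -> float:
    """Mean squared error between predicted and observed coverage."""
    return sum((p - o) ** 2 for p, o in zip(predicted, observed)) / len(observed)


def combined_loss(
    dna_logits: list[list[float]],
    dna_targets: list[int],
    rna_pred: list[float],
    rna_obs: list[float],
    rna_weight: float = 1.0,
) -> float:
    """Hypothetical weighted sum of per-head losses.

    In a real model, minimizing this sum updates the shared backbone with signal from both heads.
    """
    return cross_entropy(dna_logits, dna_targets) + rna_weight * mse(rna_pred, rna_obs)


# Two positions, four-letter vocabulary (A, C, G, T)
dna_logits = [[2.0, 0.1, 0.1, 0.1], [0.1, 0.1, 3.0, 0.1]]
dna_targets = [0, 2]
print(combined_loss(dna_logits, dna_targets, rna_pred=[1.5, 2.0], rna_obs=[1.0, 2.5], rna_weight=0.5))
```

The weight lets one head's loss be scaled so it does not dominate the other when their magnitudes differ.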

In [None]:
# First, download the pretrained Evo2 1B checkpoint and convert it to NeMo 2 format for fine-tuning
if not os.path.exists("nemo2_evo2_1b_8k"):
    !evo2_convert_to_nemo2 \
      --model-path hf://arcinstitute/savanna_evo2_1b_base \
      --model-size 1b --output-dir nemo2_evo2_1b_8k

In [None]:
# Configure the training dataset
from pathlib import Path


output_pfx = str(Path(os.path.abspath("preprocessed_data")) / output_prefix)
output_yaml = f"""
- dataset_prefix: {output_pfx}_byte-level_train
  dataset_split: train
  dataset_weight: 1.0
- dataset_prefix: {output_pfx}_byte-level_val
  dataset_split: validation
  dataset_weight: 1.0
- dataset_prefix: {output_pfx}_byte-level_test
  dataset_split: test
  dataset_weight: 1.0
"""
with open("training_data_config.yaml", "w") as f:
    print(output_yaml, file=f)

In [None]:
# Install the `heads` module from the source tree into the installed bionemo package
!cp -r \
    /workspace/sub-packages/bionemo-evo2/src/bionemo/evo2/utils/heads \
    /usr/local/lib/python3.12/dist-packages/bionemo/evo2/utils/

# Also install the `loss` module from the source tree
!cp -r \
    /workspace/sub-packages/bionemo-evo2/src/bionemo/evo2/utils/loss \
    /usr/local/lib/python3.12/dist-packages/bionemo/evo2/utils/

In [None]:
# Now let's train a model with parallel heads!
WARMUP_STEPS = 100

# For demo purposes, we use a small number of steps. Even so, training takes about 7 hours on 2 A100/RTX6000 Pro GPUs.
MAX_STEPS = 1000

# Check validation every 25 steps
VAL_CHECK_INTERVAL = 25

# Use activation checkpointing to save memory
MODEL_SUBNET_OPTION = "--activation-checkpoint-recompute-num-layers 5"

!NCCL_P2P_DISABLE=1 NCCL_IB_DISABLE=1 python \
    /workspace/sub-packages/bionemo-evo2/src/bionemo/evo2/run/train_parallel.py \
    -d training_data_config.yaml \
    --dataset-dir ./preprocessed_data \
    --result-dir parallel_pretraining_demo \
    --experiment-name evo2 \
    --model-size 1b \
    --devices 2 \
    --num-nodes 1 \
    --seq-length 8192 \
    --micro-batch-size 4 \
    --lr 0.000015 \
    --min-lr 0.0000149 \
    --warmup-steps {WARMUP_STEPS} \
    --grad-acc-batches 4 \
    --max-steps {MAX_STEPS} \
    --ckpt-dir nemo2_evo2_1b_8k \
    --clip-grad 5 \
    --wd 0.001 \
    --attention-dropout 0.01 \
    --hidden-dropout 0.01 \
    --val-check-interval {VAL_CHECK_INTERVAL} \
    {MODEL_SUBNET_OPTION} \
    --create-tensorboard-logger \
    --parallel-heads \
    --parallel-dna-head \
    --parallel-rna-seq-head \
    --ckpt-async-save
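
The checkpoint used for inference later is named after its final step and the number of samples consumed. Assuming pure data parallelism across all GPUs (no tensor or pipeline parallelism), each optimizer step consumes `micro_batch_size × grad_acc_batches × devices × num_nodes` sequences, so 1000 steps with the settings above consume 4 × 4 × 2 × 1 × 1000 = 32,000 samples, matching `consumed_samples=32000.0` in the checkpoint path, whose step index is `MAX_STEPS - 1`. The helpers below are hypothetical and only mirror the name pattern seen in this notebook.

```python
def global_batch_size(micro_batch_size: int, grad_acc_batches: int, devices: int, num_nodes: int) -> int:
    """Sequences per optimizer step, assuming every GPU is a data-parallel replica."""
    return micro_batch_size * grad_acc_batches * devices * num_nodes


def last_checkpoint_name(max_steps: int, gbs: int, epoch: int = 0) -> str:
    """Rebuild the final checkpoint directory name in the pattern used in this notebook (hypothetical)."""
    consumed = float(max_steps * gbs)
    return f"epoch={epoch}-step={max_steps - 1}-consumed_samples={consumed}-last"


gbs = global_batch_size(micro_batch_size=4, grad_acc_batches=4, devices=2, num_nodes=1)
print(gbs)  # 32
print(last_checkpoint_name(max_steps=1000, gbs=gbs))  # epoch=0-step=999-consumed_samples=32000.0-last
```

If you change `MAX_STEPS` or the batch settings, recompute the name before pointing inference at the checkpoint.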

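Now that we have a trained model, let's see how well it predicts. First, we pick an example gene with moderate RNA-seq coverage, plot its observed coverage, and write its sequence to a FASTA file to use as prediction input.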

In [None]:
# Make sure you have pyBigWig, gffutils, and Biopython installed in your environment
try:
    import gffutils
    import pyBigWig
    from Bio import SeqIO
except ImportError:
    print("Required packages not found. Installing...")
    %pip install pyBigWig gffutils biopython

In [None]:
# GTF file parsing library
import gffutils

# Plotting and coverage smoothing
import matplotlib.pyplot as plt
import pandas as pd

# BigWig file parser
import pyBigWig

# FASTA file parser
from Bio import SeqIO

In [None]:
# Load the genome FASTA into a dict keyed by sequence ID
fasta_file = "./parallel_head_data/GCA_000525045.1_DREv1_genomic.fna"
fasta_sequences = SeqIO.to_dict(SeqIO.parse(fasta_file, "fasta"))

# Build an in-memory gffutils database from the GTF/GFF annotation file
gtf_file = "./parallel_head_data/GCA_000525045.1_DREv1_genomic.gtf"  # Your GTF/GFF file path here
db = gffutils.create_db(
    gtf_file,
    dbfn=":memory:",
    force=True,
    keep_order=True,
    disable_infer_genes=True,
    disable_infer_transcripts=True,
)

# Open the forward-strand BigWig file
bigwig_forward_file = "./parallel_head_data/SRR1145649_forward.normalized.bw"  # Your BigWig file path here
bw_forward = pyBigWig.open(bigwig_forward_file)

# Open the reverse-strand BigWig file
bigwig_reverse_file = "./parallel_head_data/SRR1145649_reverse.normalized.bw"  # Your BigWig file path here
bw_reverse = pyBigWig.open(bigwig_reverse_file)
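
GTF coordinates are 1-based and inclusive at both ends, while pyBigWig's `values()` and Biopython sequence slicing use 0-based, half-open intervals. The loop below passes `gene.start - WINDOW_SIZE` straight to `values()`, so its upstream padding is one base shorter than intended; that is harmless for plotting but matters when comparing coverage base by base. The hypothetical helper below converts a padded GTF interval to a 0-based, half-open window and clamps it to the chromosome length, which pyBigWig reports via `bw.chroms(name)`.

```python
def gtf_to_window(start_1based: int, end_1based: int, pad: int, chrom_len: int) -> tuple[int, int]:
    """Convert a 1-based, inclusive GTF interval to a padded 0-based, half-open window (hypothetical helper).

    The result can be passed to pyBigWig's values(chrom, start, end) or used to slice a Biopython Seq.
    """
    start0 = start_1based - 1  # 1-based inclusive start -> 0-based start
    end0 = end_1based  # 1-based inclusive end == 0-based exclusive end
    return max(0, start0 - pad), min(chrom_len, end0 + pad)


# A feature covering bases 1..10 (1-based) of a 100 bp contig, padded by 25 bases
print(gtf_to_window(1, 10, 25, 100))  # (0, 35)
# Near the contig end, the window is clamped to the contig length
print(gtf_to_window(90, 100, 25, 100))  # (64, 100)
```

Inside the loop, this would be used as `start, end = gtf_to_window(gene.start, gene.end, WINDOW_SIZE, bw_forward.chroms(gene.chrom))`; clamping also avoids requesting positions past the end of a contig.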

In [None]:
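# Collect forward-strand transcripts from the GTF, since we are plotting the forward-strand BigWig
genes = list(db.features_of_type("transcript"))
genes = [gene for gene in genes if gene.strand == "+"]
WINDOW_SIZE = 25

# Plot the coverage for the first `total` genes that pass the coverage filters; we use these later to check predictions.
total = 1
count = 0

lower_threshold = 20  # Skip genes whose max smoothed coverage is below this value (None disables the check)
upper_threshold = 200  # Skip genes whose max smoothed coverage is above this value (None disables the check)

LOGGING: bool = False
BUILD_FASTA: bool = True
output_fasta = "./parallel_head_data/example_predict.fna"

# Sequences are appended below, so start from an empty file to avoid duplicates when this cell is re-run
if BUILD_FASTA and os.path.exists(output_fasta):
    os.remove(output_fasta)

# For each gene: extract a padded coverage window, apply the coverage filters,
# optionally write its sequence to the FASTA file, and plot it with its exons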
for gene in genes:
    # Define the region to plot (gene +/- WINDOW_SIZE)
    start = max(0, gene.start - WINDOW_SIZE)  # type: ignore
    end = gene.end + WINDOW_SIZE  # type: ignore

    exons = list(db.children(gene, featuretype="exon", order_by="start"))

    # Get the coverage data from the BigWig file for the gene region
    coverage = bw_forward.values(gene.chrom, start, end, numpy=True)

    # Smooth the coverage data using a simple moving average
    coverage = pd.Series(coverage).rolling(window=6, min_periods=1, center=True).mean().to_numpy()

    max_coverage = coverage.max()

    # Skip genes whose coverage falls outside the configured range
    if lower_threshold is not None and max_coverage < lower_threshold:
        if LOGGING:
            print(
                f"Skipping gene {gene.id} with max coverage {max_coverage} below lower threshold {lower_threshold}"
            )
        continue

    if upper_threshold is not None and max_coverage > upper_threshold:
        if LOGGING:
            print(
                f"Skipping gene {gene.id} with max coverage {max_coverage} above upper threshold {upper_threshold}"
            )
        continue

    if LOGGING:
        print(f"Processing gene: {gene.id} at {gene.chrom}:{start}-{end}")
    if BUILD_FASTA:
        # Get the padded gene sequence from the FASTA file
        seq_record = fasta_sequences[gene.chrom]
        gene_sequence = seq_record.seq[start:end]
        # Append to the output FASTA file (a separate name avoids shadowing the `fasta_file` path)
        with open(output_fasta, "a") as out_fasta:
            out_fasta.write(f">{gene.id}\n{gene_sequence}\n")

    # Create x-axis values corresponding to the gene length
    x_values = range(start, end)

    # Plot the coverage
    plt.figure(figsize=(10, 4))
    plt.plot(x_values, coverage, label="RNA-Seq Coverage", color="blue")

    # Draw a thin grey line spanning the gene at 10% of the max coverage; exons are drawn on top of it
    plt.hlines(
        y=max_coverage * 0.1,
        xmin=gene.start,  # type: ignore
        xmax=gene.end,  # type: ignore
        colors="grey",
        linewidth=1,
    )

    # Draw each exon as a thick green box on the gene line; label only the first so the legend shows one "Exon" entry
    for i, exon in enumerate(exons):
        plt.hlines(
            y=max_coverage * 0.1,
            xmin=exon.start,  # type: ignore
            xmax=exon.end,  # type: ignore
            colors="green",
            alpha=0.5,
            linewidth=10,
            label="Exon" if i == 0 else "",
        )

    # Add labels and title
    plt.title(f"RNA-Seq Coverage for Gene: {gene.id}")
    plt.xlabel("Genomic Position")
    plt.ylabel("Coverage")
    plt.legend()
    plt.tight_layout()

    # Show the plot
    plt.show()

    count += 1
    if count >= total:
        break

In [None]:
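# Now that we have a FASTA file for prediction, we can run inference with our trained model!
INPUT_FASTA = os.path.abspath("./parallel_head_data/example_predict.fna")
# The checkpoint directory name encodes the final step and the number of consumed samples;
# update it if you changed MAX_STEPS or the batch settings in the training cell.
CKPT_DIR = os.path.abspath(
    "./parallel_pretraining_demo/evo2/checkpoints/epoch=0-step=999-consumed_samples=32000.0-last"
)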
OUTPUT_DIR = os.path.abspath("./parallel_head_data/predictions")

!python \
    /usr/local/lib/python3.12/dist-packages/bionemo/evo2/utils/heads/predict.py \
    --fasta {INPUT_FASTA} \
    --ckpt-dir {CKPT_DIR} \
    --output-dir {OUTPUT_DIR} \
    --model-size 1b \
    --parallel-heads \
    --parallel-dna-head \
    --parallel-rna-seq-head

In [None]:
CLEANUP: bool = True
if CLEANUP and os.path.exists(data_path):
    !rm -rf {data_path}
    !rm -rf parallel_pretraining_demo
    !rm -rf preprocessed_data
    !rm -f parallel_preprocess_config.yaml training_data_config.yaml
    !mv /usr/local/lib/python3.12/dist-packages/bionemo/evo2/utils/config.py.bak /usr/local/lib/python3.12/dist-packages/bionemo/evo2/utils/config.py