In [2]:
! wget http://sgd-archive.yeastgenome.org/sequence/S288C_reference/genome_releases/S288C_reference_genome_R64-5-1_20240529.tgz

--2025-06-15 15:27:54--  http://sgd-archive.yeastgenome.org/sequence/S288C_reference/genome_releases/S288C_reference_genome_R64-5-1_20240529.tgz
Resolving sgd-archive.yeastgenome.org (sgd-archive.yeastgenome.org)... 52.218.222.18, 52.92.154.75, 52.92.235.35, ...
Connecting to sgd-archive.yeastgenome.org (sgd-archive.yeastgenome.org)|52.218.222.18|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 20979823 (20M) [application/gzip]
Saving to: ‘S288C_reference_genome_R64-5-1_20240529.tgz’


2025-06-15 15:27:57 (7.00 MB/s) - ‘S288C_reference_genome_R64-5-1_20240529.tgz’ saved [20979823/20979823]



In [3]:
# Extract the release archive. The "Ignoring unknown extended header keyword
# 'LIBARCHIVE.xattr.com.apple.quarantine'" messages (and the "._*" entries) come from
# macOS metadata stored in the archive; tar ignores them.
! tar -xvzf S288C_reference_genome_R64-5-1_20240529.tgz

._S288C_reference_genome_R64-5-1_20240529
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
S288C_reference_genome_R64-5-1_20240529/
S288C_reference_genome_R64-5-1_20240529/._rna_coding_R64-5-1_20240529.fasta.gz
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
S288C_reference_genome_R64-5-1_20240529/rna_coding_R64-5-1_20240529.fasta.gz
S288C_reference_genome_R64-5-1_20240529/S288C_reference_sequence_R64-5-1_20240529.fsa.gz
S288C_reference_genome_R64-5-1_20240529/._NotFeature_R64-5-1_20240529.fasta.gz
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
S288C_reference_genome_R64-5-1_20240529/NotFeature_R64-5-1_20240529.fasta.gz
S288C_reference_genome_R64-5-1_20240529/._orf_trans_all_R64-5-1_20240529.fasta.gz
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
S288C_reference_genome_R64-5-1_20240529/orf_trans_all_R64-5-1_20240529.fasta.gz
S288C_refe

In [4]:
# Decompress the reference genome FASTA (-d decompress, -k keep the .gz, -f overwrite an
# existing output) and move it to a short, fixed name used by later cells.
!gzip -dkf S288C_reference_genome_R64-5-1_20240529/S288C_reference_sequence_R64-5-1_20240529.fsa.gz
!mv S288C_reference_genome_R64-5-1_20240529/S288C_reference_sequence_R64-5-1_20240529.fsa fasta_file.fsa
# Same treatment for the GFF annotation file.
!gzip -dkf S288C_reference_genome_R64-5-1_20240529/saccharomyces_cerevisiae_R64-5-1_20240529.gff.gz
!mv S288C_reference_genome_R64-5-1_20240529/saccharomyces_cerevisiae_R64-5-1_20240529.gff gff_file.gff

In [1]:
import math

import numpy as np
import pandas as pd
import torch
from Bio import SeqIO
from datasets import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoTokenizer,
    BertConfig,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding,
)

# Understanding SpeciesLM Model Types

**SpeciesLM** offers two types of models with different training focuses:

## 🧬 **Upstream Models** (`species_upstream_1000_k1`)
- **Training Data**: 1000bp upstream regions of genes (5' regulatory regions)
- **Focus**: Promoters, enhancers, transcription factor binding sites
- **Best For**: 
  - Gene expression prediction ⭐ (Our use case)
  - Transcriptional regulation analysis
  - Promoter identification
  - Enhancer/silencer detection

## 🧬 **Downstream Models** (`downstream_species_lm`)  
- **Training Data**: Downstream regions of genes (3' regulatory regions)
- **Focus**: Terminators, polyadenylation signals, 3' UTR elements
- **Best For**:
  - mRNA stability prediction
  - Post-transcriptional regulation
  - Termination signal identification
  - microRNA binding site analysis (only meaningful in species with an RNAi/miRNA pathway — *S. cerevisiae* lacks one)

## 📊 **For RNA Expression Prediction**
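We use the **upstream model** because transcription initiation, a major control point of gene expression, is governed by upstream regulatory elements: the core promoter and the transcription-factor binding sites that regulate it (in yeast, mainly upstream activating/repressing sequences rather than distal enhancers).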

---

In [None]:
def preprocess_sequences(
    ds,
    tokenizer,
    proxy_species,
    seq_col="five_prime_seq",
    num_proc=16,
):
    """Tokenize every sequence in a DataFrame for SpeciesLM input (no masking).

    Each value of ``ds[seq_col]`` is split into single characters joined by
    spaces, prefixed with ``proxy_species`` and a space, and passed to
    ``tokenizer``.

    :param ds: pandas DataFrame containing the sequences.
    :param tokenizer: tokenizer callable applied to each prepared string.
    :param proxy_species: species token prepended to every sequence.
    :param seq_col: name of the column holding the sequences.
    :param num_proc: number of worker processes used by ``Dataset.map``.
    :returns: Hugging Face ``Dataset`` with the original columns plus the
        tokenizer outputs (e.g. ``input_ids``).
    """

    def _encode_row(row):
        # One token per character, so characters are space-separated.
        spaced_chars = " ".join(row[seq_col])
        return tokenizer(proxy_species + " " + spaced_chars)

    encoded = Dataset.from_pandas(ds).map(_encode_row, num_proc=num_proc)
    return encoded.flatten_indices()

In [None]:
class DNALMReconstructor:
    """Per-nucleotide embedding extractor built on a SpeciesLM masked language model.

    Loads a BERT-style masked LM plus its tokenizer and exposes
    :meth:`get_vectors_for_each_nucleotide`, which averages hidden states of
    overlapping sliding windows into one vector per input nucleotide.
    """

    def __init__(
        self,
        model_path="/s/project/denovo-prosit/JohannesHingerl/BERTADN/final_models/species_upstream_1000_k1/",
        tokenizer_path=None,
        proxy_species="kazachstania_africana_cbs_2517_gca_000304475",
        use_hooks=False,
    ):
        """
        Initialize the DNALMReconstructor with a pre-trained language model.

        :param model_path: Path to the pre-trained model (config + weights).
        :param tokenizer_path: Path to the tokenizer; defaults to ``model_path``.
        :param proxy_species: Species token prepended to every input window.
        :param use_hooks: Currently unused; kept for backward compatibility.
        """
        tokenizer_path = tokenizer_path if tokenizer_path else model_path
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.config = BertConfig.from_pretrained(model_path)
        self.model = AutoModelForMaskedLM.from_pretrained(
            self.model_path, config=self.config
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        print(f"Using device: {self.device}")
        # Token ids of the four nucleotides; index [1] skips the special token that
        # encode() presumably prepends (e.g. [CLS]) — TODO confirm for this tokenizer.
        self.nuc_idx = [self.tokenizer.encode(nuc)[1] for nuc in ["A", "C", "G", "T"]]
        self.proxy_species = proxy_species
        self.activation = None

    @staticmethod
    def _window_starts(seq_len, window_size, stride):
        """Return sorted, unique, non-negative window start positions.

        Starts advance by ``stride``; if the last regular window does not end
        exactly at ``seq_len``, one extra window aligned to the sequence end is
        added. Sequences shorter than ``window_size`` get a single window at 0.

        :raises ValueError: if ``window_size`` or ``stride`` is not positive.
        """
        if window_size <= 0 or stride <= 0:
            raise ValueError("window_size and stride must be positive")
        if seq_len <= 0:
            return []
        last_start = max(seq_len - window_size, 0)
        starts = list(range(0, last_start + 1, stride))
        if starts[-1] != last_start:
            starts.append(last_start)
        return starts

    def get_vectors_for_each_nucleotide(
        self,
        sequence: str,
        window_size: int = 1000,
        stride: int = 200,
        num_proc: int = 16,
        batch_size: int = 64,
        layers_from: int = 8,
    ):
        """Return one embedding vector per nucleotide of ``sequence``.

        The sequence is cut into overlapping windows (see ``_window_starts``),
        each window is embedded as the mean of hidden states from index
        ``layers_from`` onward, and every nucleotide's vector is the average over
        all windows covering it. Each window appears exactly once, so no window
        is over-weighted in that average.

        :param sequence: DNA sequence to embed.
        :param window_size: length of each window in nucleotides.
        :param stride: step between consecutive window starts.
        :param num_proc: worker processes used for tokenization.
        :param batch_size: number of windows per forward pass.
        :param layers_from: first index into ``hidden_states`` to average.
        :returns: CPU float32 ``torch.Tensor`` of shape
            ``(len(sequence), config.hidden_size)``.
        """
        embedding_dim = self.config.hidden_size
        print(f"Embedding dimension: {embedding_dim}")
        seq_len = len(sequence)
        if seq_len == 0:
            return torch.zeros((0, embedding_dim), dtype=torch.float32)

        sum_embeddings = torch.zeros(
            (seq_len, embedding_dim), dtype=torch.float32, device=self.device
        )
        count_embeddings = torch.zeros(seq_len, dtype=torch.float32, device=self.device)

        # Unique, in-bounds window starts; each window gets "ATG" appended because
        # the model was trained on upstream + start codon sequences, and we don't
        # want the input to be out of distribution.
        start_positions = self._window_starts(seq_len, window_size, stride)
        sequences = [sequence[s : s + window_size] + "ATG" for s in start_positions]
        df = pd.DataFrame({"seq": sequences, "start_pos": start_positions})

        tok_ds = preprocess_sequences(
            df, self.tokenizer, self.proxy_species, "seq", num_proc=num_proc
        )

        for i in tqdm(range(math.ceil(len(tok_ds) / batch_size))):
            rows = tok_ds[i * batch_size : (i + 1) * batch_size]  # slice once per batch
            batch = torch.tensor(rows["input_ids"])
            idx = torch.tensor(rows["start_pos"], dtype=torch.long, device=self.device)

            with torch.no_grad():
                res = self.model(batch.to(self.device), output_hidden_states=True)[
                    "hidden_states"
                ]  # tuple of (B, L+6, H) tensors
                layer_embeddings = res[layers_from:]
            mean_embedding = torch.mean(
                torch.stack(layer_embeddings, dim=0), dim=0
            )  # (B, L+6, H)
            # Drop the 2 leading tokens (presumably [CLS] + species token) and the
            # 4 trailing ones (the appended "ATG" + presumably [SEP]).
            mean_embedding = mean_embedding[:, 2:-4, :]
            assert mean_embedding.shape[0] == len(idx), (
                f"Mean embedding shape: {mean_embedding.shape}, idx length: {len(idx)}"
            )

            B, L, H = mean_embedding.shape
            position_offsets = torch.arange(L, device=idx.device)  # (L,)
            positions = idx.unsqueeze(1) + position_offsets  # (B, L) absolute positions

            positions_flat = positions.reshape(-1)  # (B*L,)
            values_flat = mean_embedding.reshape(-1, H)  # (B*L, H)
            ones = torch.ones(B * L, device=idx.device)  # (B*L,)

            # Accumulate sums and coverage counts per nucleotide position.
            sum_embeddings.index_add_(0, positions_flat, values_flat)
            count_embeddings.index_add_(0, positions_flat, ones)

        num_zeros = (count_embeddings == 0).sum().item()
        assert num_zeros == 0, (
            f"Some nucleotides are not covered by any window ({num_zeros})"
        )
        return (sum_embeddings / count_embeddings.unsqueeze(1)).cpu().to(torch.float32)

In [None]:
# Import our new SpeciesLM embedder module from the project's src/ directory
import sys

SRC_DIR = "../src"
if SRC_DIR not in sys.path:  # keep the cell idempotent: no duplicate entries on re-runs
    sys.path.append(SRC_DIR)

from model.species_lm_embedder import SpeciesLMEmbedder

# Initialize the embedder for gene expression prediction (upstream model)
print("🧬 Creating embedder for gene expression prediction (upstream model)...")
embedder = SpeciesLMEmbedder.for_gene_expression(
    proxy_species="saccharomyces_cerevisiae",
    device="auto",  # Will use CUDA if available, otherwise CPU
)

print("SpeciesLM embedder initialized successfully!")
print(f"Model device: {embedder.device}")
print(f"Model revision: {embedder.model_revision}")
print(f"Embedding dimension: {embedder.embedding_dim}")
print(f"Species: {embedder.proxy_species}")

# Optionally, you can also create a downstream embedder for comparison
print("\n🧬 You could also create an embedder for mRNA stability (downstream model):")
print("embedder_downstream = SpeciesLMEmbedder.for_mrna_stability()")

In [None]:
import os
from pathlib import Path

# Define output directory (parents=True also creates any missing parent folders)
output_directory = Path("../embeddings")
output_directory.mkdir(parents=True, exist_ok=True)

# Use our new embedder to process the FASTA file
fasta_file_path = "fasta_file.fsa"

print("Starting embedding process...")
print(f"Input FASTA: {fasta_file_path}")
print(f"Output directory: {output_directory}")

# Embed all sequences in the FASTA file using sliding windows for per-nucleotide embeddings
output_files = embedder.embed_fasta_file(
    fasta_path=fasta_file_path,
    output_dir=str(output_directory),
    mode="sliding",  # Per-nucleotide embeddings
    window_size=1000,  # 1kb windows
    stride=50,  # 50bp stride: more overlapping windows per nucleotide (presumably averaged) at higher compute cost
    layers_from=8,  # Hidden states from layer 8 onward (layers 8-12 for a 12-layer model)
    batch_size=16,  # Batch size for memory efficiency
    compress=True,  # Compress output files with gzip
)

print(f"Successfully processed {len(output_files)} sequences!")

# Display results; only the first sequence is loaded to show summary statistics
for seq_idx, (seq_id, output_path) in enumerate(output_files.items()):
    print(f"  {seq_id}: {output_path}")

    if seq_idx == 0:
        embeddings = embedder.load_embeddings(output_path)
        stats = embedder.get_embedding_stats(embeddings)
        print(f"    Shape: {stats['shape']}")
        print(f"    Mean: {stats['mean']:.4f}")
        print(f"    Std: {stats['std']:.4f}")

print("Embedding process completed!")

In [None]:
# Example: Load and analyze embeddings for downstream tasks
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

SAMPLE_EVERY = 100  # subsample positions for PCA / clustering speed

# Load embeddings for the first chromosome
first_seq_id = list(output_files.keys())[0]
first_embedding_file = output_files[first_seq_id]

print(f"Loading embeddings for {first_seq_id}...")
embeddings = embedder.load_embeddings(first_embedding_file)

print(f"Embedding shape: {embeddings.shape}")
print(f"This represents {embeddings.shape[0]} nucleotide positions")
print(f"Each position has a {embeddings.shape[1]}-dimensional embedding")

# NumPy view so slicing and norms behave the same for a NumPy array or a CPU tensor.
embeddings_np = np.asarray(embeddings)
sampled_embeddings = embeddings_np[::SAMPLE_EVERY]
sampled_positions = np.arange(0, len(embeddings_np), SAMPLE_EVERY)  # bp of each sample

# Example 1: Dimensionality reduction with PCA
print("\n1. Performing PCA for visualization...")
pca = PCA(n_components=2)
embeddings_2d = pca.fit_transform(sampled_embeddings)

plt.figure(figsize=(10, 6))
plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], alpha=0.6, s=1)
plt.title(f"PCA of DNA Embeddings for {first_seq_id}")
plt.xlabel("First Principal Component")
plt.ylabel("Second Principal Component")
plt.show()

# Example 2: Find similar regions using clustering
print("\n2. Clustering similar genomic regions...")
# Explicit n_init: the default changed across scikit-learn versions.
kmeans = KMeans(n_clusters=5, random_state=42, n_init=10)
clusters = kmeans.fit_predict(sampled_embeddings)

plt.figure(figsize=(12, 4))
plt.scatter(sampled_positions, clusters, c=clusters, cmap="viridis", alpha=0.7)
plt.title(f"Genomic Region Clusters for {first_seq_id}")
plt.xlabel(f"Genomic Position (bp, every {SAMPLE_EVERY}th position)")
plt.ylabel("Cluster")
plt.colorbar()
plt.show()

print(f"Found {len(set(clusters))} distinct types of genomic regions")

# Example 3: Compute embedding statistics along the chromosome
print("\n3. Computing embedding statistics along chromosome...")
window_size = 1000
step_size = 500

# Per-position norms are computed once; each window then averages a slice of them.
position_norms = np.linalg.norm(embeddings_np, axis=1)
# "+ 1" so the final full window (ending exactly at the chromosome end) is included.
window_starts = range(0, len(position_norms) - window_size + 1, step_size)
embedding_norms = [position_norms[s : s + window_size].mean() for s in window_starts]
positions = [s + window_size // 2 for s in window_starts]

plt.figure(figsize=(15, 4))
plt.plot(positions, embedding_norms)
plt.title(f"Average Embedding Norm Along {first_seq_id}")
plt.xlabel("Genomic Position (bp)")
plt.ylabel("Average Embedding Norm")
plt.show()

print("Analysis complete! The embeddings are ready for downstream tasks.")
print("\nPossible uses:")
print("- Predict gene expression levels")
print("- Identify regulatory elements")
print("- Classify genomic regions")
print("- Compare sequences across species")