This notebook describes how to use mRNA-LM (individual models) for extracting embeddings from saved (pretrained) models: CodonBERT, 5' and 3' UTRBERTS. 

- First download the pretrained models from the mentioned links (see [README](https://github.com/Sanofi-Public/mRNA-LM/tree/main?tab=readme-ov-file#pre-trained-models)). We save them inside `pretrained_models/`
- Unzip the model file: `unzip <model_file.zip>`

#### CDS Model

In [1]:
import torch
from transformers import AutoModel, AutoConfig

model_path = "./pretrained_models/codonbert"

config = AutoConfig.from_pretrained(model_path, local_files_only=True)
cds_model = AutoModel.from_pretrained(model_path, config=config, local_files_only=True)
print(cds_model)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(69, 768, padding_idx=0)
    (position_embeddings): Embedding(1024, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)


In [2]:
#CDS Tokenizer
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import BertProcessing
from transformers import PreTrainedTokenizerFast

def build_standard_codon_tokenizer():
    bases = ['A', 'U', 'G', 'C']
    special_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']
    codons = [a + b + c for a in bases for b in bases for c in bases]  # 4^3 = 64 codons
    vocab = special_tokens + codons
    vocab_dict = {tok: i for i, tok in enumerate(vocab)}

    tokenizer = Tokenizer(WordLevel(vocab=vocab_dict, unk_token='[UNK]'))
    tokenizer.add_special_tokens(special_tokens)
    tokenizer.pre_tokenizer = Whitespace()
    tokenizer.post_processor = BertProcessing(
        ("[SEP]", vocab_dict["[SEP]"]),
        ("[CLS]", vocab_dict["[CLS]"])
    )

    return PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        unk_token='[UNK]',
        sep_token='[SEP]',
        pad_token='[PAD]',
        cls_token='[CLS]',
        mask_token='[MASK]'
    )

tokenizer = build_standard_codon_tokenizer()


assert len(tokenizer.get_vocab()) == 69
assert max(tokenizer.get_vocab().values()) == 68


tokenizer.save_pretrained("./codon_tokenizer")

('./codon_tokenizer/tokenizer_config.json',
 './codon_tokenizer/special_tokens_map.json',
 './codon_tokenizer/tokenizer.json')

In [3]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import BatchEncoding
from typing import List, Union

def preprocess_CDS_sequence(seq: str) -> str:
    """Convert DNA to RNA and split into codons."""
    seq = seq.replace("T", "U")
    codons = [seq[i:i+3] for i in range(0, len(seq), 3) if len(seq[i:i+3]) == 3]
    return ' '.join(codons)

def batch_tokenize(sequences: List[str], tokenizer, max_length=1024) -> BatchEncoding:
    """Preprocess and tokenize a list of RNA sequences."""
    processed = [preprocess_CDS_sequence(seq) for seq in sequences]
    return tokenizer(
        processed,
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=max_length
    )

def extract_CDS_embeddings_batch(sequences: List[str], tokenizer, model, mode='cls', batch_size=8) -> List[torch.Tensor]:
    """Extract embeddings for a list of sequences in batches."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    embeddings = []

    # Create mini-batches
    for i in range(0, len(sequences), batch_size):
        batch_seqs = sequences[i:i+batch_size]
        inputs = batch_tokenize(batch_seqs, tokenizer)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            last_hidden = outputs.last_hidden_state  # [B, T, H]

            if mode == 'cls':
                batch_embeddings = last_hidden[:, 0, :]  # [B, H]
            elif mode == 'mean':
                mask = inputs['attention_mask'].unsqueeze(-1).expand(last_hidden.size())
                masked = last_hidden * mask
                sum_embeddings = masked.sum(dim=1)
                lengths = mask.sum(dim=1).clamp(min=1e-9)
                batch_embeddings = sum_embeddings / lengths  # [B, H]
            else:
                raise ValueError("mode must be 'cls' or 'mean'")

            embeddings.extend(batch_embeddings.cpu())

    return embeddings  # List of [hidden_size] tensors

def build_CDS_embedding_dataframe(sequences: Union[List[str], None] = None,
                                   csv_path: Union[str, None] = None,
                                   column: Union[str, None] = None,
                                   mode: str = 'cls',
                                   tokenizer=None,
                                   model=None,
                                   batch_size: int = 8) -> pd.DataFrame:
    """Return a DataFrame with sequences and their embeddings."""
    if csv_path:
        df = pd.read_csv(csv_path)
        sequences = df[column].tolist()
    elif sequences is None:
        sequences = [
            "ATGCGATTTTCTAAAGTAAATGTT",
            "ATGCCCGGGAAATTAGCTAA",
            "ATGAAATTTCCCGGGTTTAA",
            "ATGATATATATATGCGCGCGCGC",
            "ATGCGCGTATATATATATAGTAG"
        ]

    embeddings = extract_CDS_embeddings_batch(sequences, tokenizer, model, mode, batch_size)
    return pd.DataFrame({
        "sequence": sequences,
        "embedding": [emb.numpy() for emb in embeddings]
    })

In [4]:
#Example
df = build_CDS_embedding_dataframe(mode="mean", tokenizer=tokenizer, model=cds_model, batch_size=4)
display(df.head())
print(f"Length of CDS embedding: {len(df['embedding'][0])}")

Unnamed: 0,sequence,embedding
0,ATGCGATTTTCTAAAGTAAATGTT,"[0.26536927, -0.4860135, 0.10935764, -0.181523..."
1,ATGCCCGGGAAATTAGCTAA,"[0.42503604, -0.41654238, 0.13327605, -0.00635..."
2,ATGAAATTTCCCGGGTTTAA,"[0.13614827, -0.14827953, 0.17040502, -0.17318..."
3,ATGATATATATATGCGCGCGCGC,"[0.5961642, -0.2065507, 0.15518683, -0.1861409..."
4,ATGCGCGTATATATATATAGTAG,"[0.49323305, -0.3121234, 0.2550852, 0.00702007..."


Length of CDS embedding: 768


#### 3' and 5' UTR models

In [5]:
#UTR tokenizer
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import BertProcessing
from transformers import PreTrainedTokenizerFast

def build_nt_tokenizer():
    special_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']
    nucleotides = list("AUGCN")  # RNA bases
    vocab = special_tokens + nucleotides
    vocab_dict = {token: idx for idx, token in enumerate(vocab)}
    
    tokenizer = Tokenizer(WordLevel(vocab=vocab_dict, unk_token='[UNK]'))
    tokenizer.add_special_tokens(special_tokens)
    tokenizer.pre_tokenizer = Whitespace()
    tokenizer.post_processor = BertProcessing(
        ("[SEP]", vocab_dict["[SEP]"]),
        ("[CLS]", vocab_dict["[CLS]"])
    )
    
    utr_tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        unk_token='[UNK]',
        sep_token='[SEP]',
        pad_token='[PAD]',
        cls_token='[CLS]',
        mask_token='[MASK]'
    )

    return utr_tokenizer

utr_tokenizer = build_nt_tokenizer()

print(len(utr_tokenizer.get_vocab()))
print(utr_tokenizer.get_vocab().values())

assert len(utr_tokenizer.get_vocab()) == 10
assert max(utr_tokenizer.get_vocab().values()) == 9

utr_tokenizer.save_pretrained("./utr_tokenizer")

10
dict_values([2, 0, 5, 6, 1, 9, 7, 3, 8, 4])


('./utr_tokenizer/tokenizer_config.json',
 './utr_tokenizer/special_tokens_map.json',
 './utr_tokenizer/tokenizer.json')

In [6]:
import pandas as pd
import torch
from typing import List, Union
from transformers import BatchEncoding

def preprocess_UTR_sequence(seq: str) -> str:
    """Convert DNA to RNA and add whitespace between bases."""
    seq = seq.replace("T", "U")
    return ' '.join(seq)

def batch_tokenize_utr(sequences: List[str], tokenizer, max_length=1024) -> BatchEncoding:
    """Preprocess and tokenize a list of UTR sequences (single nucleotide tokens)."""
    processed = [preprocess_UTR_sequence(seq) for seq in sequences]
    return tokenizer(
        processed,
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=max_length
    )

def extract_UTR_embeddings_batch(sequences: List[str], tokenizer, model, mode='cls', batch_size=8, max_length=1024) -> List[torch.Tensor]:
    """Extract UTR embeddings in batches."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    embeddings = []

    for i in range(0, len(sequences), batch_size):
        batch_seqs = sequences[i:i+batch_size]
        inputs = batch_tokenize_utr(batch_seqs, tokenizer, max_length=max_length)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            last_hidden = outputs.last_hidden_state  # [B, T, H]

            if mode == 'cls':
                batch_embeddings = last_hidden[:, 0, :]  # [B, H]
            elif mode == 'mean':
                mask = inputs['attention_mask'].unsqueeze(-1).expand(last_hidden.size())
                masked = last_hidden * mask
                sum_embeddings = masked.sum(dim=1)
                lengths = mask.sum(dim=1).clamp(min=1e-9)
                batch_embeddings = sum_embeddings / lengths
            else:
                raise ValueError("mode must be 'cls' or 'mean'")

            embeddings.extend(batch_embeddings.cpu())

    return embeddings

def build_UTR_embedding_dataframe(region: str,
                                  sequences: Union[List[str], None] = None,
                                  csv_path: Union[str, None] = None,
                                  column: Union[str, None] = None,
                                  mode: str = 'cls',
                                  tokenizer=None,
                                  model=None,
                                  batch_size: int = 8) -> pd.DataFrame:
    """
    Extract UTR embeddings (5' or 3') and return a DataFrame.
    
    region: '5utr' or '3utr'
    """
    assert region in ["5utr", "3utr"], "Region must be '5utr' or '3utr'"

    max_len = 512 if region == "5utr" else 1024

    if csv_path:
        df = pd.read_csv(csv_path)
        sequences = df[column].tolist()
    elif sequences is None:
        sequences = [
            "ATGCGATTTTCTAAAGTAAATGTT",
            "ATGCCCGGGAAATTAGCTAA",
            "ATGAAATTTCCCGGGTTTAA",
            "ATGATATATATATGCGCGCGCGC",
            "ATGCGCGTATATATATATAGTAG"
        ]

    embeddings = extract_UTR_embeddings_batch(
        sequences, tokenizer, model, mode=mode, batch_size=batch_size, max_length=max_len
    )

    return pd.DataFrame({
        "sequence": sequences,
        "embedding": [emb.numpy() for emb in embeddings]
    })

In [7]:
from transformers import AutoModel, AutoTokenizer

model_path = "./pretrained_models/mrna_5utr_model_p2_cp85600_best"

config = AutoConfig.from_pretrained(model_path, local_files_only=True)
model_5utr = AutoModel.from_pretrained(model_path, config=config, local_files_only=True)
tokenizer_utr = AutoTokenizer.from_pretrained("./utr_tokenizer") 

df_5utr = build_UTR_embedding_dataframe(
    region="5utr",
    mode="mean",
    tokenizer=tokenizer_utr,
    model=model_5utr,
    batch_size=4
)

display(df_5utr)
print(len(df_5utr['embedding'][0]))

Some weights of BertModel were not initialized from the model checkpoint at ./pretrained_models/mrna_5utr_model_p2_cp85600_best and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Unnamed: 0,sequence,embedding
0,ATGCGATTTTCTAAAGTAAATGTT,"[0.35632232, 0.71432865, -0.7655516, 0.7875675..."
1,ATGCCCGGGAAATTAGCTAA,"[-0.36231306, 0.30640456, -0.574126, 1.2636681..."
2,ATGAAATTTCCCGGGTTTAA,"[0.25398535, 0.28713754, -0.19550385, 0.673237..."
3,ATGATATATATATGCGCGCGCGC,"[0.67684096, -0.078162774, -0.52981377, 0.7569..."
4,ATGCGCGTATATATATATAGTAG,"[0.3242923, 0.07677141, 0.14978021, 0.48897576..."


768


In [8]:
from transformers import AutoModel, AutoTokenizer

model_path = "./pretrained_models/mrna_3utr_model_p2_cp99900_best"

config = AutoConfig.from_pretrained(model_path, local_files_only=True)
model_3utr = AutoModel.from_pretrained(model_path, config=config, local_files_only=True)
tokenizer_utr = AutoTokenizer.from_pretrained("./utr_tokenizer") 

df_3utr = build_UTR_embedding_dataframe(
    region="3utr",
    mode="mean",
    tokenizer=tokenizer_utr,
    model=model_3utr,
    batch_size=4
)

display(df_3utr)
print(len(df_3utr['embedding'][0]))

Some weights of BertModel were not initialized from the model checkpoint at ./pretrained_models/mrna_3utr_model_p2_cp99900_best and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Unnamed: 0,sequence,embedding
0,ATGCGATTTTCTAAAGTAAATGTT,"[-0.46467862, -0.50936043, 0.36862156, 0.86998..."
1,ATGCCCGGGAAATTAGCTAA,"[-0.39854795, -0.74053574, -0.4361593, -0.2348..."
2,ATGAAATTTCCCGGGTTTAA,"[-0.037209284, -0.41736507, 0.18577996, 0.4326..."
3,ATGATATATATATGCGCGCGCGC,"[-0.43701595, -0.052947313, -0.2544344, 0.2598..."
4,ATGCGCGTATATATATATAGTAG,"[-0.8236446, -0.2006473, -0.6237884, 0.6729365..."


768
