In [None]:
import gc
import os
import random
import re

import numpy as np
import torch
from bertviz import head_view, model_view
from transformers import (
    AutoModelForMaskedLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)

In [None]:
# One seed shared by every random number generator used in this notebook.
rand = 42

# Seed the global generators of the standard library and NumPy.
random.seed(rand)
np.random.seed(rand)

# Record the seed in the process environment under PYTHONHASHSEED.
os.environ.update(PYTHONHASHSEED=str(rand))

# Kept as the final expression so the cell still displays the torch Generator.
torch.manual_seed(rand)

<torch._C.Generator at 0x106444c50>

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
codon_mdl_path = "adibvafa/CodonTransformer"
codon_tok = AutoTokenizer.from_pretrained(codon_mdl_path)
# The checkpoint is loaded with its masked-LM head. Loading it as a
# sequence classifier attaches a classifier head that is not in the
# checkpoint and is therefore randomly initialised (transformers warns about
# exactly this). Only the encoder's attention maps are used below, so no
# classification head is needed.
codon_mdl = AutoModelForMaskedLM.from_pretrained(
    codon_mdl_path, output_attentions=True
)
codon_mdl.to(device)
# Inference mode (disables dropout); as the last expression this also displays
# the model architecture.
codon_mdl.eval()

Some weights of BigBirdForSequenceClassification were not initialized from the model checkpoint at adibvafa/CodonTransformer and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BigBirdForSequenceClassification(
  (bert): BigBirdModel(
    (embeddings): BigBirdEmbeddings(
      (word_embeddings): Embedding(90, 768, padding_idx=0)
      (position_embeddings): Embedding(4096, 768)
      (token_type_embeddings): Embedding(164, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BigBirdEncoder(
      (layer): ModuleList(
        (0-11): 12 x BigBirdLayer(
          (attention): BigBirdAttention(
            (self): BigBirdBlockSparseAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): BigBirdSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True

In [None]:
# Pull the encoder depth and per-layer head count from the loaded model's config.
num_layers, num_attention_heads = (
    codon_mdl.config.num_hidden_layers,
    codon_mdl.config.num_attention_heads,
)

print(f"The model has {num_layers} layers.")
print(f"Each layer has {num_attention_heads} attention heads.")

The model has 12 layers.
Each layer has 12 attention heads.


In [None]:
def clear_gpu_cache():
    """Free unreachable Python objects, then release cached CUDA memory.

    The garbage collector runs first so that any tensors it frees are already
    gone when ``torch.cuda.empty_cache()`` releases the allocator's unused
    cached blocks. On machines without CUDA only the garbage collection
    step runs.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
def view_attention(tokenizer, model, sequence, view='model'):
    """Render a BertViz attention visualisation for a single sequence.

    Args:
        tokenizer: Tokenizer that is callable as
            ``tokenizer(sequence, return_tensors='pt')`` and provides
            ``convert_ids_to_tokens``.
        model: Model whose forward pass accepts ``output_attentions=True``
            and returns the per-layer attention tensors.
        sequence: Text to tokenize and feed to the model.
        view: ``'head'`` for ``head_view`` (opened on the last layer) or
            ``'model'`` for ``model_view`` (all layers and heads).

    Raises:
        ValueError: If ``view`` is neither ``'head'`` nor ``'model'``.
    """
    # Reject unknown views up front instead of running the model and then
    # silently displaying nothing.
    if view not in ('head', 'model'):
        raise ValueError(
            f"Unsupported view {view!r}; expected 'head' or 'model'."
        )

    # Put the inputs on the device the model actually lives on; fall back to
    # the CUDA-if-available heuristic for a model without parameters.
    try:
        model_device = next(model.parameters()).device
    except StopIteration:
        model_device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )

    # Encode sequence with tokenizer
    encoded = tokenizer(sequence, return_tensors='pt')
    input_ids = encoded['input_ids'].to(model_device)

    # Get attention weights from model. Attentions are requested explicitly so
    # this works however the model was configured, and gradients are disabled
    # because this is pure inference.
    with torch.no_grad():
        outputs = model(input_ids, output_attentions=True)
    if hasattr(outputs, "attentions"):
        attention = outputs.attentions
    else:
        # Tuple-style outputs: attentions are the last element.
        attention = outputs[-1]

    # Get tokens for visualization
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    # BertViz visualizers. The third positional parameter of both functions is
    # ``sentence_b_start`` (for sentence pairs), not a layer index, so the
    # layer is selected by keyword instead.
    if view == 'head':
        head_view(attention, tokens, layer=len(attention) - 1)
    else:
        model_view(attention, tokens)

    # Drop every reference to the attention tensors before clearing the cache;
    # deleting only ``attention`` would leave them alive through ``outputs``.
    del outputs, attention, input_ids, encoded, tokens
    clear_gpu_cache()

In [None]:
# Print the tokenizer's vocabulary (a mapping from token string to integer id).
print(codon_tok.vocab)
# List the tokenizer's attribute names; as the cell's last expression the
# result is displayed as the cell output.
dir(codon_tok)

{'i_ata': 38, 'm_unk': 15, 'y_tac': 75, 'q_cag': 44, 'r_cgc': 51, 'r_cgt': 53, '__tag': 76, 'a_gcg': 64, 'i_unk': 12, '__taa': 74, 'd_unk': 7, 'p_ccg': 48, 't_unk': 21, 'l_tta': 86, 'd_gat': 61, 'k_aag': 28, '[PAD]': 3, 'g_ggg': 68, 'g_ggt': 69, 'p_unk': 17, 'e_gaa': 58, 'w_tgg': 84, 'v_gtc': 71, 'y_unk': 24, 's_tct': 81, '[UNK]': 0, 'r_cga': 50, 'e_unk': 8, 'g_gga': 66, 'e_gag': 60, 's_agt': 37, 't_acg': 32, 's_tcg': 80, 'l_unk': 14, 'a_gcc': 63, 'a_gca': 62, 'n_aac': 27, 'r_cgg': 52, 'p_ccc': 47, 'c_tgt': 85, '[CLS]': 1, 'v_gtt': 73, 'g_unk': 10, 'f_ttc': 87, 'p_cca': 46, 's_tcc': 79, 'l_ctt': 57, 't_acc': 31, '[MASK]': 4, '__unk': 25, 'y_tat': 77, 'q_unk': 18, 't_aca': 30, 'g_ggc': 67, 'p_cct': 49, 'r_aga': 34, 'q_caa': 42, 'm_atg': 40, 'l_ctc': 55, 'v_unk': 22, 'n_unk': 16, 'k_aaa': 26, 'v_gta': 70, 'f_ttt': 89, 'f_unk': 9, 'h_unk': 11, 'c_unk': 6, 'a_gct': 65, 'r_agg': 36, 't_act': 33, 'd_gac': 59, 'r_unk': 19, 'l_cta': 54, 'l_ctg': 56, 'a_unk': 5, 'k_unk': 13, 'h_cac': 43, '__tga

In [None]:
# Space-separated tokens in the same form as the tokenizer's vocabulary entries.
codon_sequence = "m_atg a_gcc e_gaa __taa"

# Visualise the attention the model assigns over this short sequence.
view_attention(
    tokenizer=codon_tok,
    model=codon_mdl,
    sequence=codon_sequence,
    view='model',
)

['m_atg', 'a_gcc', 'e_gaa', '__taa']


<IPython.core.display.Javascript object>

In [None]:
from typing import Union, List
from transformers import (
    AutoTokenizer,
    BatchEncoding,
    BigBirdConfig,
    BigBirdForMaskedLM,
    PreTrainedTokenizerFast,
)

from CodonTransformer.CodonPrediction import predict_dna_sequence

from CodonTransformer.CodonUtils import (
    AMINO_ACID_TO_INDEX,
    INDEX2TOKEN,
    NUM_ORGANISMS,
    ORGANISM2ID,
    TOKEN2INDEX,
    DNASequencePrediction,
)
from CodonTransformer.CodonJupyter import format_model_output

In [None]:

def predict_dna_sequencen_annotated(
    protein: str,
    organism: Union[int, str],
    device: torch.device,
    tokenizer: Union[str, PreTrainedTokenizerFast] = None,
    model: Union[str, torch.nn.Module] = None,
    attention_type: str = "original_full",
    deterministic: bool = True,
    temperature: float = 0.2,
    top_p: float = 0.95,
    num_sequences: int = 1,
    match_protein: bool = False,
) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
    """Log the request, then predict a DNA sequence with CodonTransformer.

    Thin wrapper around ``predict_dna_sequence``: it prints the protein and
    organism being processed and forwards every argument unchanged.

    Args:
        protein: Protein sequence to predict a DNA sequence for.
        organism: Target organism, passed through as given.
        device: Torch device forwarded to the prediction call.
        tokenizer: Tokenizer, passed through as given.
        model: Model, passed through as given.
        attention_type, deterministic, temperature, top_p, num_sequences,
            match_protein: Forwarded unchanged to ``predict_dna_sequence``.

    Returns:
        Whatever ``predict_dna_sequence`` returns for these arguments.
    """
    print(f"Predicting DNA sequence for protein: {protein}, organism: {organism}")

    # Previously the function ended after the print and returned None despite
    # its declared return type.
    # NOTE(review): this assumes the installed CodonTransformer release accepts
    # all of these keyword arguments - confirm against the installed version.
    return predict_dna_sequence(
        protein=protein,
        organism=organism,
        device=device,
        tokenizer=tokenizer,
        model=model,
        attention_type=attention_type,
        deterministic=deterministic,
        temperature=temperature,
        top_p=top_p,
        num_sequences=num_sequences,
        match_protein=match_protein,
    )

# Load the tokenizer and the masked-LM model used for codon prediction.
tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer").to(device)

# Set your input data
protein = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG"
organism = "Escherichia coli general"

# Both calls below share the inputs defined above, so editing `protein` or
# `organism` changes both of them consistently.
result = predict_dna_sequencen_annotated(
    protein=protein,
    organism=organism,
    device=device,
    tokenizer=tokenizer,
    model=model
)
print(result)

# Predict with CodonTransformer
output = predict_dna_sequence(
    protein=protein,
    organism=organism,
    device=device,
    tokenizer=tokenizer,
    model=model,
    attention_type="original_full",
    deterministic=True
)
print(format_model_output(output))