In [None]:
# NOTE: You may need to run this twice due to a pip dependency conflict
%pip install https://github.com/braceal/cpe.git

In [1]:


import functools
import time
import warnings
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from contextlib import ExitStack
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import os
import h5py
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
import torch
import numpy.typing as npt
from Bio import SeqIO  # type: ignore[import]

from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BatchEncoding, PreTrainedTokenizerFast, BertForMaskedLM
from cpe.utils import (
    gc_content,
    get_label_dict,
    parse_sequence_labels,
    preprocess_data,
    read_fasta,
    read_fasta_only_seq
)
from cpe.dataset import GenSLMColatorForLanguageModeling, FastaDataset, llm_inference
from all_cluster_visuzlization import PlotClustersData


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Lookup table from a three-letter DNA codon to the single character that
# stands for it (e.g. "ATG" -> "n"). The 64 codons over A/C/G/T are assigned,
# in order, to A-Z, a-z, 0-9, "!" and "@"; the extra "XXX" key maps to "*".
# NOTE(review): "XXX" looks like a placeholder for an unknown codon - confirm.
CODON_TO_CHAR = {
    # --- "A" .. "Z" ---
    "TCG": "A", "GCA": "B", "CTT": "C", "ATT": "D",
    "TTA": "E", "GGG": "F", "CGT": "G", "TAA": "H",
    "AAA": "I", "CTC": "J", "AGT": "K", "CCA": "L",
    "TGT": "M", "GCC": "N", "GTT": "O", "ATA": "P",
    "TAC": "Q", "TTT": "R", "TGC": "S", "CAC": "T",
    "ACG": "U", "CCC": "V", "ATC": "W", "CAT": "X",
    "AGA": "Y", "GAG": "Z",
    # --- "a" .. "z" ---
    "GTG": "a", "GGT": "b", "GCT": "c", "TTC": "d",
    "AAC": "e", "TAT": "f", "GTA": "g", "CCG": "h",
    "ACA": "i", "CGA": "j", "TAG": "k", "CTG": "l",
    "GGA": "m", "ATG": "n", "TCT": "o", "CGG": "p",
    "GAT": "q", "ACC": "r", "GAC": "s", "GTC": "t",
    "TGG": "u", "CCT": "v", "GAA": "w", "TCA": "x",
    "CAA": "y", "AAT": "z",
    # --- "0" .. "9" ---
    "ACT": "0", "GCG": "1", "GGC": "2", "CTA": "3",
    "AAG": "4", "AGG": "5", "CAG": "6", "AGC": "7",
    "CGC": "8", "TTG": "9",
    # --- punctuation ---
    "TCC": "!", "TGA": "@", "XXX": "*",
}

In [8]:
# Root of the local `cpe` checkout. Set the CPE_ROOT environment variable to
# run the notebook on another machine; when it is unset the original
# locations are used unchanged.
CPE_ROOT = Path(os.environ.get("CPE_ROOT", "/home/couchbucks/Documents/saketh/cpe"))

# Path to the FASTA file with the input sequences.
fasta_path = str(CPE_ROOT / "data" / "datasets" / "mdh" / "mdh_natural_dataset.fasta")

# Path to the tokenizer (either a single tokenizer .json file or a saved
# tokenizer directory).
tokenizer_path = str(CPE_ROOT / "cpe" / "cpe_tokenizer_retrained_3000")

# Path to the model checkpoint that is loaded with `from_pretrained`.
model_checkpoint = str(
    CPE_ROOT / "cpe" / "checkpoints" / "bpe" / "cpe_tokenizer" / "bert" / "checkpoint-34000"
)


In [10]:
# Load the tokenizer. A path that points at a single file is a serialized
# tokenizer .json and is loaded through the `tokenizer_file` constructor
# argument (calling `from_pretrained` on a single file is deprecated in
# transformers). Anything else is handed to `from_pretrained` as a saved
# tokenizer directory.
if os.path.isfile(tokenizer_path):
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
else:
    tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)

# For some reason the special tokens have to be registered explicitly even
# though they are already present in the tokenizer's json file.
special_tokens = {
    "unk_token": "[UNK]",
    "cls_token": "[CLS]",
    "sep_token": "[SEP]",
    "pad_token": "[PAD]",
    "mask_token": "[MASK]",
    "bos_token": "[BOS]",
    "eos_token": "[EOS]",
}
tokenizer.add_special_tokens(special_tokens)

# Load the checkpoint, move it to the GPU when one is available, and switch
# it to evaluation mode. The last expression is left bare so the notebook
# displays the model architecture.
model = BertForMaskedLM.from_pretrained(model_checkpoint)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(3000, 240, padding_idx=3)
      (position_embeddings): Embedding(1024, 240)
      (token_type_embeddings): Embedding(2, 240)
      (LayerNorm): LayerNorm((240,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=240, out_features=240, bias=True)
              (key): Linear(in_features=240, out_features=240, bias=True)
              (value): Linear(in_features=240, out_features=240, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=240, out_features=240, bias=True)
              (LayerNorm): LayerNorm((240,), eps=1e-12, elementwise_aff

In [11]:
# Seed every generation with "n", the CPE character for the codon "ATG".
prompt = tokenizer.encode("n", return_tensors="pt").to(device)

tokens = model.generate(
    prompt,
    # Upper bound on the generated length; increase for longer sequences.
    # NOTE(review): the original note said 300 here means ~900 bases, which
    # assumes one codon per token - confirm for this tokenizer.
    max_length=50,
    min_length=10,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    num_return_sequences=128,  # number of sequences to generate
    remove_invalid_values=True,
    use_cache=True,
    # Use the pad id registered on the tokenizer. Encoding the string "[PAD]"
    # and taking element 0 yields the first id of an encoded sequence, which
    # is not the pad id whenever the tokenizer prepends a special token.
    pad_token_id=tokenizer.pad_token_id,
    temperature=1.0,
)

# Turn the generated id tensors back into strings, dropping special tokens.
generated_sequences = tokenizer.batch_decode(tokens, skip_special_tokens=True)

KeyboardInterrupt: 

In [None]:
# TODO: resolve this error -> ImportError: cannot import name 'GenSLMColatorForLanguageModeling' from 'dataset' (/home/couchbucks/Documents/saketh/cpe/cpe/dataset.py)


In [None]:
# Run `llm_inference` on the FASTA file with the tokenizer and checkpoint
# configured above. The call returns three values; only the first one
# (`embeddings`) is kept, the other two are discarded.
embeddings, _, _ = llm_inference(
    tokenizer_path,
    model_checkpoint,
    fasta_path,
    return_codon=False,
    return_aminoacid=False,
    batch_size=128,
    fasta_contains_aminoacid=False,
)

In [115]:
# Project the embeddings to two dimensions for plotting. t-SNE is stochastic,
# so the seed is fixed (same value as the other `random_state` arguments in
# this notebook) to make the projection reproducible across runs.
tsne_embeddings = TSNE(n_components=2, random_state=1).fit_transform(embeddings)


In [150]:
# Build the plotting helper from the sequences and their 2-D t-SNE
# coordinates; no class labels are supplied.
# NOTE(review): `sequences` is not assigned in any cell shown in this
# notebook - confirm where it is expected to come from before a fresh
# "Restart & Run All".
embedding_visualization = PlotClustersData(
    sequences=sequences,
    tsne_hidden_states=tsne_embeddings,
    labels=[],
    label_dict={},
    tokenizer_type="CPE Tokenizer",
)

In [None]:
# GC content view: `plot_gc_content()` supplies the three plotting inputs,
# which are then passed straight to `plot_clusters`.
plot_df_separate, hue_separate, plt_title = embedding_visualization.plot_gc_content()
embedding_visualization.plot_clusters(plot_df_separate, hue_separate, plt_title)

In [None]:
# Sequence length view: `plot_seq_len()` supplies the three plotting inputs,
# which are then passed straight to `plot_clusters`.
plot_df_separate, hue_separate, plt_title = embedding_visualization.plot_seq_len()
embedding_visualization.plot_clusters(plot_df_separate, hue_separate, plt_title)

In [None]:
# Molecular weight view: `plot_molecular_weight()` supplies the three plotting
# inputs, which are then passed straight to `plot_clusters`.
plot_df_separate, hue_separate, plt_title = embedding_visualization.plot_molecular_weight()
embedding_visualization.plot_clusters(plot_df_separate, hue_separate, plt_title)

In [None]:
# Isoelectric point view: `plot_isoelectric_point()` supplies the three
# plotting inputs, which are then passed straight to `plot_clusters`.
plot_df_separate, hue_separate, plt_title = embedding_visualization.plot_isoelectric_point()
embedding_visualization.plot_clusters(plot_df_separate, hue_separate, plt_title)

In [None]:
# Aromaticity view: `plot_aromaticity()` supplies the three plotting inputs,
# which are then passed straight to `plot_clusters`.
plot_df_separate, hue_separate, plt_title = embedding_visualization.plot_aromaticity()
embedding_visualization.plot_clusters(plot_df_separate, hue_separate, plt_title)

In [None]:
# Instability index view: `plot_instability_index()` supplies the three
# plotting inputs, which are then passed straight to `plot_clusters`.
plot_df_separate, hue_separate, plt_title = embedding_visualization.plot_instability_index()
embedding_visualization.plot_clusters(plot_df_separate, hue_separate, plt_title)

In [None]:
# Flexibility view: `plot_flexibility()` supplies the three plotting inputs,
# which are then passed straight to `plot_clusters`.
plot_df_separate, hue_separate, plt_title = embedding_visualization.plot_flexibility()
embedding_visualization.plot_clusters(plot_df_separate, hue_separate, plt_title)

In [None]:
# Fit a small MLP classifier on the embeddings and report its accuracy on a
# stratified train/test split (both seeded with random_state=1).
# NOTE(review): `labels` is not assigned in any cell shown in this notebook -
# confirm where it is expected to come from before a fresh "Restart & Run All".
X_train, X_test, y_train, y_test = train_test_split(
    embeddings,
    labels,
    stratify=labels,
    random_state=1,
)

# `fit` returns the estimator itself, so building and fitting in two steps
# leaves `clf` bound to the same fitted classifier.
clf = MLPClassifier(random_state=1, max_iter=300)
clf.fit(X_train, y_train)

print(f"MLP model train accuracy: {clf.score(X_train, y_train)}")
print(f"MLP model test accuracy: {clf.score(X_test, y_test)}")