In [79]:
from pathlib import Path
from collections import defaultdict, Counter
from genslm.utils import read_fasta_only_seq

In [80]:
CODON_MAP = {
    "TCG": 0,
    "GCA": 1,
    "CTT": 2,
    "ATT": 3,
    "TTA": 4,
    "GGG": 5,
    "CGT": 6,
    "TAA": 7,
    "AAA": 8,
    "CTC": 9,
    "AGT": 10,
    "CCA": 11,
    "TGT": 12,
    "GCC": 13,
    "GTT": 14,
    "ATA": 15,
    "TAC": 16,
    "TTT": 17,
    "TGC": 18,
    "CAC": 19,
    "ACG": 20,
    "CCC": 21,
    "ATC": 22,
    "CAT": 23,
    "AGA": 24,
    "GAG": 25,
    "GTG": 26,
    "GGT": 27,
    "GCT": 28,
    "TTC": 29,
    "AAC": 30,
    "TAT": 31,
    "GTA": 32,
    "CCG": 33,
    "ACA": 34,
    "CGA": 35,
    "TAG": 36,
    "CTG": 37,
    "GGA": 38,
    "ATG": 39,
    "TCT": 40,
    "CGG": 41,
    "GAT": 42,
    "ACC": 43,
    "GAC": 44,
    "GTC": 45,
    "TGG": 46,
    "CCT": 47,
    "GAA": 48,
    "TCA": 49,
    "CAA": 50,
    "AAT": 51,
    "ACT": 52,
    "GCG": 53,
    "GGC": 54,
    "CTA": 55,
    "AAG": 56,
    "AGG": 57,
    "CAG": 58,
    "AGC": 59,
    "CGC": 60,
    "TTG": 61,
    "TCC": 62,
    "TGA": 63,
}

# Assign a unique character to each codon so that we can use it as an
# input token to a BPE tokenizer. This implements a codon-pair encoding.
CODON_CHAR = {
    "TCG": "A",
    "GCA": "B",
    "CTT": "C",
    "ATT": "D",
    "TTA": "E",
    "GGG": "F",
    "CGT": "G",
    "TAA": "H",
    "AAA": "I",
    "CTC": "J",
    "AGT": "K",
    "CCA": "L",
    "TGT": "M",
    "GCC": "N",
    "GTT": "O",
    "ATA": "P",
    "TAC": "Q",
    "TTT": "R",
    "TGC": "S",
    "CAC": "T",
    "ACG": "U",
    "CCC": "V",
    "ATC": "W",
    "CAT": "X",
    "AGA": "Y",
    "GAG": "Z",
    "GTG": "a",
    "GGT": "b",
    "GCT": "c",
    "TTC": "d",
    "AAC": "e",
    "TAT": "f",
    "GTA": "g",
    "CCG": "h",
    "ACA": "i",
    "CGA": "j",
    "TAG": "k",
    "CTG": "l",
    "GGA": "m",
    "ATG": "n",
    "TCT": "o",
    "CGG": "p",
    "GAT": "q",
    "ACC": "r",
    "GAC": "s",
    "GTC": "t",
    "TGG": "u",
    "CCT": "v",
    "GAA": "w",
    "TCA": "x",
    "CAA": "y",
    "AAT": "z",
    "ACT": "0",
    "GCG": "1",
    "GGC": "2",
    "CTA": "3",
    "AAG": "4",
    "AGG": "5",
    "CAG": "6",
    "AGC": "7",
    "CGC": "8",
    "TTG": "9",
    "TCC": "!",
    "TGA": "@",
    "XXX": "*"
}


def group_by_kmer(seq: str, kmer: int) -> str:
    return " ".join(seq[i : i + kmer] for i in range(0, len(seq), kmer)).upper()

In [81]:
data_path = Path("/Users/abrace/Downloads/generated_mdh-gene_seqs.fasta")

sequences = read_fasta_only_seq(data_path)
sequences = [group_by_kmer(seq, 3) for seq in sequences]  # Group by 3-mer (codon)
sequences = [
    "".join(CODON_CHAR[codon] for codon in seq.split()) for seq in sequences
]  # Convert codons to unique chars

In [82]:
sequences

['nBjVIWB9WFNb6Wb2UlBX9B1WIwJFstildqDNw2rhymI!lqWBwNbh!w2Rq1Ul4FUzqfNqWBq1qOSWgUB2ahp4hFn78qqClbDel4OnIAa2ZFDNNX1Vs1RaWSWrehlq1ntu1l8wR!2lhTN4at2nNbaJqA18d8TdlAlwdetAn8stU1da92FX2qUnahla8f!UOBbDVlVqlW1n2uiU6w4lqbDO68rGqbb1wDablC2UFA1RQBV1r71WwnNZNfJzs688JlvS1Nfaq2Bl2leFnfa2ahUWW21zFWwIaaqW5neqqZ61ndeIAOq1a429twNS42Ws7AlBH',
 'n4WitP2c2zgbN0iclCJ1y46JFqWOJasttZ2WV6FIBJqnX6!bhtCbRsU6ar2!zqQcBrBsxxWttWU1FC15Yh2nKpAsll6BeNNWapxOOwZBt4fAhqrDJCOaUzLlqanrfCNJu47B6xYvpad2n!FOlq0N8dYnRWBwwlsWx1cstZ1laWmKX2qNnlV9hpf0UtK2WLDr6lCAZs4WeNlP48r6sbbBZWtpXlI72K1dfchFBoO1pnaZBPNesKI5aJhB!1llq2wfbEIs0Ql2ahaIEmcmmDwZWWZlqJr!sZ4FC9p4xBwFap4igNwJVgH',
 'n4rva8ONtrbcN26WKf!lldpDc!2ZnJ2Is6VaWJylJZWrhNn41JZbtanZlsqMNdhllNbWZUNsqVztNd4sr6a1llaFNnh84NFnZ8!qJlzNe22Wd4V62ZNJ7pKNp4sa4alat2ehNerzNlW168zN1qlIvZ6dSNna8lsTe8NDA6WNe4t2N1t7sa4eaWWubeT!!r6dhsJQTNraN24!1JF1as6suQZeZQWVpaN48b1ZWWZ1pFNAANN!NNeNNWsXa8sunJ2Uhq2sut!nFta!q2!Q2ahZ2ltQ2dharS4s2Zu7Wt6bJsWsZdA824nsNU16wJrZZ8s1aN4J2ltk',
 'n744Vt8OcOr2BBbyWbQNllRGDNo

In [76]:
sequences = sequences[:2]  # Take a small set for testing

In [77]:
# Compute the count of each codon in the dataset
codon_freqs = Counter("".join(sequences))
codon_freqs

Counter({'l': 35,
         'W': 30,
         'q': 29,
         '2': 28,
         'a': 27,
         '1': 26,
         'B': 23,
         'N': 21,
         'n': 20,
         'w': 18,
         'F': 16,
         'b': 16,
         's': 16,
         't': 16,
         '4': 16,
         'h': 15,
         'U': 14,
         '6': 13,
         'J': 13,
         'Z': 12,
         'I': 11,
         'r': 11,
         'O': 11,
         '8': 11,
         'f': 10,
         'A': 10,
         'd': 9,
         'D': 9,
         '!': 9,
         'p': 9,
         'C': 9,
         'e': 9,
         'V': 8,
         '9': 6,
         'X': 6,
         'R': 6,
         'z': 6,
         'c': 6,
         'x': 6,
         'i': 5,
         'm': 5,
         '7': 5,
         'K': 5,
         'S': 4,
         'g': 4,
         '0': 4,
         'u': 3,
         'Q': 3,
         '5': 3,
         'P': 3,
         'Y': 3,
         'y': 2,
         'T': 2,
         'v': 2,
         'H': 2,
         'L': 2,
         'E': 2,
     

In [None]:
def compute_pair_freqs(splits):
    pair_freqs = defaultdict(int)
    for word, freq in codon_freqs.items():
        split = splits[word]
        if len(split) == 1:
            continue
        for i in range(len(split) - 1):
            pair = (split[i], split[i + 1])
            pair_freqs[pair] += freq
    return pair_freqs

In [78]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer