In [10]:
# The original code was written by
# https://github.com/S-Hauri/MSA_Transformer_Generator/blob/main/MSA_Transformer_sequence_generation.py

# @article{mcgee2020generative,
#   title={Generative Capacity of Probabilistic Protein Sequence Models},
#   author={McGee, Francisco and Novinger, Quentin and Levy, Ronald M and Carnevale, Vincenzo and Haldane, Allan},
#   journal={arXiv preprint arXiv:2012.02296},
#   year={2020}
# }

import sys
import numpy as np
from Bio import SeqIO
import itertools
from typing import List, Tuple
import string
from tqdm import tqdm
import random
from pprint import pprint

import esm
import torch

file_name = "seq/6M.exper_10k.seqs"
save_name = "esm_trial_gen"

torch.set_printoptions(precision=3, sci_mode=False)
torch.set_grad_enabled(False)

# This is an efficient way to delete lowercase characters and insertion characters from a string
deletekeys = dict.fromkeys(string.ascii_lowercase)
deletekeys["."] = None
deletekeys["*"] = None
translation = str.maketrans(deletekeys)


# python esm_generator.py seq/6M.exper.seqs gen_esm_msa 629257 4 256 32
def read_sequence(filename: str) -> Tuple[str, str]:
    """Reads the first (reference) sequences from a fasta or MSA file."""
    record = next(SeqIO.parse(filename, "fasta"))
    return record.description, str(record.seq)


def remove_insertions(sequence: str) -> str:
    """Removes any insertions into the sequence. Needed to load aligned sequences in an MSA."""
    return sequence.translate(translation)


def read_msa(filename: str, nseq: int) -> List[Tuple[str, str]]:
    """Reads the first nseq sequences from an MSA file, automatically removes insertions."""
    return [
        (record.description, remove_insertions(str(record.seq)))
        for record in itertools.islice(SeqIO.parse(filename, "fasta"), nseq)
    ]


def loaded_msa(msa, nseq) -> List[Tuple[str, str]]:
    """Reads the nseq sequences at constant intervals from an MSA file, automatically removes insertions."""
    N = len(msa)
    # split into chunks of approximately equal lengths (https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length)
    splits = np.array_split(range(1, N), nseq)
    output = []
    for spl in splits:
        record = msa[spl[0]]
        output.append((record["description"], remove_insertions(str(record["seq"]))))
    return output


def loaded_msa_all(msa) -> List[Tuple[str, str]]:
    """Reads the nseq sequences at constant intervals from an MSA file, automatically removes insertions."""
    N = len(msa)
    return [
        (record["description"], remove_insertions(str(record["seq"])))
        for record in itertools.islice(msa, 1, N)
    ]


# https://stackoverflow.com/questions/57237596/how-to-improve-np-random-choice-looping-efficiency
def vectorized_choice(p, n, items=None):
    s = p.cumsum(axis=1)
    r = np.random.rand(p.shape[0], n, 1)
    q = np.expand_dims(s, 1) >= r
    k = q.argmax(axis=-1)
    if items is not None:
        k = np.asarray(items)[k]
    return k

In [2]:
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

else:
    mps_device = torch.device("mps")

In [3]:
""" prepare Protein data """
Lines = open(file_name, "r").readlines()
seqs = [l.replace("\n", "") for l in Lines if len(l) > 3]
msa = [{"description": "noname", "seq": seq} for seq in seqs]
msa_data = loaded_msa_all(msa)

In [4]:
len(msa_data)

9999

In [5]:
# alphabet = "ACDEFGHIKLMNPQRSTVWY-"
alphabet = "ABCD"
alph_dict = {}
for i, a in enumerate(alphabet):
    alph_dict[a] = i

M = len(seqs)
L = len(seqs[0])
A = len(alphabet)
print(f"M, total sequences = {M}")
print(f"L, length of each sequence = {L}")
print(f"A, length of alphabet = {A}")

M, total sequences = 10000
L, length of each sequence = 99
A, length of alphabet = 4


In [6]:
one_hot = np.zeros((M, L, A))
for m in range(M):
    for i in range(L):
        one_hot[m, i, alph_dict[seqs[m][i]]] = 1

In [7]:
one_hot.shape

(10000, 99, 4)

In [8]:
counts = one_hot.sum(0)
indep = counts / counts.sum(-1).reshape((-1, 1))
entropy_per_pos = (-indep * np.log(indep + 1e-9)).sum(-1)
pos_seq = (entropy_per_pos.argsort() + 1).tolist()  # +1 to correct for start token


msa_transformer, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
# msa_transformer = msa_transformer.eval().cuda()
msa_batch_converter = msa_alphabet.get_batch_converter()


standard_idx = [msa_alphabet.get_idx(tok) for tok in alphabet]
L = len(msa_data[0][1]) + 1  # add start token
A = len(standard_idx)
new_seq = []

In [47]:
standard_idx

[5, 25, 23, 13]

In [39]:
msa_transformer = msa_transformer.to(mps_device)

In [27]:
n_generate = 200
n_stack = 4
n_batch = 128
n_mask = 66
save_interval = 2

for i in tqdm(range(0, n_generate, n_mask * n_stack)):
    msa_batch_data = []
    for s in range(n_stack):
        # Randomly sample one batch worth of indices, used to pull sequences
        idxs = random.sample(range(len(msa_data)), n_batch)
        msa_batch_data.append([msa_data[i] for i in idxs])

    msa_batch_labels, msa_batch_strs, msa_batch_tokens = msa_batch_converter(
        msa_batch_data
    )
    msa_batch_tokens = msa_batch_tokens.to(mps_device)
    new_tokens = msa_batch_tokens.clone()
    
    # mask certain proteins entirely (except start token)
    prot_idxs = np.random.randint(n_batch, size=n_mask)
    new_tokens[:, prot_idxs, 1:] = msa_alphabet.mask_idx
    pprint(new_tokens[2][3])

100%|██████████| 1/1 [00:00<00:00,  3.52it/s]

tensor([ 0, 25, 23, 23, 13, 13, 13, 23, 13, 23, 23, 13,  5, 25, 23, 13, 13,  5,
        25, 23, 13,  5, 25, 23,  5, 25, 13, 25,  5, 23, 25, 13, 23,  5,  5, 13,
        13, 23, 25, 23,  5, 23,  5, 23, 25, 23, 23,  5, 25, 23,  5, 23, 25, 23,
        23, 23, 23, 23, 25, 13, 25,  5, 13, 13, 23, 25, 25, 23, 25, 23, 25,  5,
        23,  5,  5, 23,  5, 23,  5,  5, 23, 23, 23, 25,  5, 25,  5, 13, 25, 23,
        23,  5,  5, 25, 25, 25, 23, 23, 25, 25], device='mps:0')





In [40]:
output = msa_transformer(new_tokens)

In [46]:
prot_idxs

array([ 10, 111, 104,  48,  55,  17,  74,  89,  25,  18, 105,  14,  34,
        68,  56,  75, 106,  14,  34,   9, 127,  43,  31,  88,  75,   2,
       101,  33,  81,   5,  69, 127,  59,  96,  27,  26,  74,  17,  98,
        83,  30,  45,  22,  80, 105, 114,  99,  62,  90,   5,  63,   9,
        83,  89,  46,  95, 109,  89,  42, 102,  39,  44,  20,  11,  51,
       111])

In [None]:
for i in tqdm(range(0, n_generate, n_mask * n_stack)):
    msa_batch_data = []
    for s in range(n_stack):
        idxs = random.sample(range(len(msa_data)), n_batch)
        msa_batch_data.append([msa_data[i] for i in idxs])

    msa_batch_labels, msa_batch_strs, msa_batch_tokens = msa_batch_converter(
        msa_batch_data
    )
    msa_batch_tokens = msa_batch_tokens.to(mps_device)

    new_tokens = msa_batch_tokens.clone()

    prot_idxs = np.random.randint(n_batch, size=n_mask)
    new_tokens[
        :, prot_idxs, 1:
    ] = msa_alphabet.mask_idx  # mask certain proteins entirely (except start token)
    
    # generate {batch_size} samples, one position at a time
    for pos in pos_seq:
        # run model and gather masked probabilities
        output = msa_transformer(new_tokens)
        probs = (
            torch.softmax(output["logits"][:, prot_idxs][:, :, pos, standard_idx], -1)
            .detach()
            .cpu()
            .numpy()
        )
        # sample random tokens based on predicted probabilities (Gibbs sampling)
        rand_res = vectorized_choice(probs.reshape((n_stack * n_mask, A)), 1).flatten()
        toks = [standard_idx[t] for t in rand_res]
        toks = torch.tensor(toks).reshape((n_stack, n_mask)).cuda()
        # replace mask with samples
        idxs_scat = torch.tensor(prot_idxs, dtype=int).cuda().expand(n_stack, -1)
        new_tokens[:, :, pos].scatter_(1, idxs_scat, toks)
    new_tokens = new_tokens.detach().cpu().numpy()
    new_seq.append(
        new_tokens[:, prot_idxs, 1:].reshape((-1, L - 1))
    )  # drop start token

    if len(new_seq) * n_stack * n_mask > save_interval:
        new_seq = np.concatenate(new_seq)
        new_strs = []
        for seq in new_seq:
            chars = [msa_alphabet.get_tok(idx) for idx in seq]
            new_strs.append("".join(chars))

        with open(save_name + ".txt", "a") as file_handler:
            for item in new_strs:
                file_handler.write("{}\n".format(item))
        new_seq = []

In [None]:
new_seq = np.concatenate(new_seq)
new_strs = []
for seq in new_seq:
    chars = [msa_alphabet.get_tok(idx) for idx in seq]
    new_strs.append("".join(chars))

with open(save_name + ".txt", "a") as file_handler:
    for item in new_strs:
        file_handler.write("{}\n".format(item))
new_seq = []