In [1]:
import torch
import numpy as np
from Bio import SeqIO
import h5py
import re
from tqdm import tqdm
import os


### load config

In [2]:
import yaml

# Load the run configuration; the `with` block closes the file handle
# (the bare open(...) call previously leaked it).
with open("embed.yaml") as config_file:
    config = yaml.safe_load(config_file)

In [3]:
# Unpack run settings from embed.yaml.
# pLM selects the language model; data lists the dataset file-name stems to process.
pLM, data = config['pLM'], config['data']

# Input locations for FASTA files and .npz label archives.
data_dir, label_dir = config['data_dir'], config['label_dir']

# Output locations for the HDF5 embeddings and the truncated label archives.
embed_save_dir, label_save_dir = config['embed_save_dir'], config['label_save_dir']

# Maximum chunk length (in residues) for sequence truncation.
max_len = config['truncate']

In [7]:
# Create the output directories, including any missing parent directories.
# exist_ok=True makes this a no-op when they already exist, so the cell is safe to re-run.
os.makedirs(embed_save_dir, exist_ok=True)
os.makedirs(label_save_dir, exist_ok=True)

### helper functions

In [13]:
def read_fasta(fasta_path, split_char="|", id_field=1):
    '''
        Reads a FASTA file containing one or more sequences.

        split_char and id_field control how the identifier is extracted from each
        header: the header (without its leading '>') is split on split_char and
        field id_field is used. E.g. split_char="|" and id_field=1 suits
        SwissProt/UniProt headers such as ">sp|P12345|NAME".

        "/" and "." in identifiers are replaced with "_" because they are
        mis-interpreted when the ids are later used as HDF5 dataset names.
        Sequence lines are joined, stripped of whitespace and gaps ("-"),
        upper-cased, and the non-standard residues U, Z, O and B are mapped to X
        (the same set prep_data maps).

        Returns a dict {identifier: sequence}; it is empty for an empty file.

        Raises:
            IndexError: if a header has fewer than id_field + 1 fields.
            ValueError: if sequence data appears before the first header.
    '''
    seqs = dict()
    uniprot_id = None
    with open(fasta_path, 'r') as fasta_f:
        for line in fasta_f:
            if line.startswith('>'):
                # get the ID from the header and create a new entry
                uniprot_id = line[1:].strip().split(split_char)[id_field]
                # replace tokens that are mis-interpreted when loading h5
                uniprot_id = uniprot_id.replace("/", "_").replace(".", "_")
                seqs[uniprot_id] = ''
                continue

            # remove all white-space chars (joins seqs spanning multiple lines),
            # drop gaps and cast to upper-case
            seq = ''.join(line.split()).upper().replace("-", "")
            if not seq:
                # blank line: nothing to append
                continue
            if uniprot_id is None:
                raise ValueError(
                    "Sequence data found before the first FASTA header in {}".format(fasta_path))
            # map non-standard AAs to unknown/X
            seq = seq.replace('U', 'X').replace('Z', 'X').replace('O', 'X').replace('B', 'X')
            seqs[uniprot_id] += seq

    print("Read {} sequences.".format(len(seqs)))
    if seqs:
        # guard: next(iter({})) would raise StopIteration on an empty file
        example_id = next(iter(seqs))
        print("Example:\n{}\n{}".format(example_id, seqs[example_id]))

    return seqs


In [33]:
def prep_data(fasta_path, label_path, max_len=512, verbose=False,
              split_char="|", id_field=1):
    """Pair FASTA sequences with their labels, splitting long sequences into chunks.

    Each record's accession is field `id_field` of its id split on `split_char`
    (the defaults match UniProt/SwissProt ids such as ``sp|P12345|NAME``); it is
    used as the key into the ``.npz`` label archive. U, Z, O and B are mapped to X.

    A label of shape ``(1,)`` is copied unchanged to every chunk. Any other label
    is sliced along axis 1 in step with the sequence, and its axis-1 length must
    equal the sequence length (presumably one column per residue -- TODO confirm
    against the label files).

    Sequences with ``len(seq) <= max_len`` keep their accession. Longer ones are
    cut into consecutive chunks of at most `max_len` residues. Each chunk's
    accession gets its start offset appended (``ACC_0``, ``ACC_512``, ...).

    Args:
        fasta_path: path to the FASTA file.
        label_path: path to the ``.npz`` archive keyed by accession.
        max_len: maximum chunk length in residues; must be positive.
        verbose: print per-sequence progress.
        split_char: separator used to split the record id.
        id_field: index of the accession within the split record id.

    Returns:
        (acc_ls, seq_ls, label_ls): three parallel lists.

    Raises:
        ValueError: if `max_len` is not positive.
        KeyError: if an accession has no entry in the label archive.
        AssertionError: if a sliced label's length does not match its sequence.
    """
    if max_len <= 0:
        # With a non-positive step, range() below would either raise or yield
        # nothing, silently dropping every sequence longer than max_len.
        raise ValueError(f"max_len must be a positive integer, got {max_len!r}")

    print("==========prepare data===========")
    acc_ls = []
    seq_ls = []
    label_ls = []

    # The context manager closes the .npz archive. Arrays indexed out of it are
    # read into memory, so they remain valid after the file is closed.
    with np.load(label_path) as labels:
        for seq_record in SeqIO.parse(fasta_path, "fasta"):
            acc = seq_record.id.split(split_char)[id_field]
            # replace non-standard amino acids with X (length-preserving)
            seq = re.sub(r"[UZOB]", "X", str(seq_record.seq))
            label = labels[acc]
            per_residue = label.shape != (1,)

            if len(seq) <= max_len:
                # short enough: keep accession, sequence and label as-is
                if per_residue:
                    assert len(seq) == label.shape[1], (
                        f"{acc}: sequence length {len(seq)} != label length {label.shape[1]}")
                temp_str = seq[:3] + "..." + seq[-3:] if len(seq) > 6 else seq
                if verbose:
                    print(
                        f"adding {acc}: {temp_str}\nseq len: {len(seq)}, label shape: {label.shape}")
                acc_ls.append(acc)
                seq_ls.append(seq)
                label_ls.append(label)
                continue

            if verbose:
                print(
                    f"splitting {acc}: {seq[:3]}...{seq[-3:]}\nseq len: {len(seq)}, label shape: {label.shape}")

            # too long: split into consecutive chunks of at most max_len residues
            for start in range(0, len(seq), max_len):
                chunk = seq[start:start + max_len]
                if per_residue:
                    chunk_label = label[:, start:start + max_len]
                    assert len(chunk) == chunk_label.shape[1], (
                        f"{acc}_{start}: chunk length {len(chunk)} != label length {chunk_label.shape[1]}")
                else:
                    chunk_label = label
                if verbose:
                    print(
                        f"adding trunc {start}\nseq len: {len(chunk)}, label shape: {chunk_label.shape}")
                acc_ls.append(f"{acc}_{start}")
                seq_ls.append(chunk)
                label_ls.append(chunk_label)

    assert len(acc_ls) == len(seq_ls) == len(label_ls)
    return acc_ls, seq_ls, label_ls


In [4]:
def save_embeddings(emb_dict, out_path):
    """Write every entry of `emb_dict` to an HDF5 file as one dataset per key.

    `out_path` is opened in "w" mode, so an existing file is overwritten. Keys
    become dataset names and values become the dataset data. The file is created
    with track_order=True, presumably so datasets can later be iterated in
    insertion order.

    Returns None.
    """
    with h5py.File(str(out_path), "w", track_order=True) as h5_out:
        for seq_id in emb_dict:
            h5_out.create_dataset(seq_id, data=emb_dict[seq_id])

In [5]:
def retrive_embedding(path):
    """Read every item under the root group of the HDF5 file at `path`.

    Prints each item's name as it is read. Returns a list of (name, data) tuples
    in the file's iteration order, where data is the item read in full via `[()]`.
    """
    with h5py.File(path, 'r', track_order=True) as h5_in:
        root_group = h5_in["/"]
        loaded = []
        for name, dset in root_group.items():
            print(name)
            loaded.append((name, dset[()]))
    return loaded


### load language model

In [6]:
# load pLM
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

if pLM == "Ankh":
    import ankh
    model, tokenizer = ankh.load_large_model()
elif pLM == "ProtT5":
    from transformers import T5EncoderModel, T5Tokenizer
    tokenizer = T5Tokenizer.from_pretrained(
        'Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained(
        "Rostlab/prot_t5_xl_half_uniref50-enc")
    # Use half precision on GPU only; keep full fp32 precision on CPU.
    # The comparison must use device.type: a torch.device never equals the
    # plain string 'cpu'. nn.Module has no .full(); fp32 is .float().
    model = model.float() if device.type == 'cpu' else model.half()
else:
    raise ValueError(f"Unsupported pLM {pLM!r}; expected 'Ankh' or 'ProtT5'")

# Inference only: switch to eval mode and move the weights to the device.
model.eval()
model.to(device=device)

Some weights of the model checkpoint at ElnaggarLab/ankh-large were not used when initializing T5EncoderModel: ['decoder.block.1.layer.0.SelfAttention.o.weight', 'decoder.block.10.layer.0.layer_norm.weight', 'decoder.block.16.layer.0.SelfAttention.v.weight', 'decoder.block.5.layer.0.SelfAttention.o.weight', 'decoder.embed_tokens.weight', 'decoder.block.5.layer.2.DenseReluDense.wo.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.12.layer.0.SelfAttention.v.weight', 'decoder.block.8.layer.0.SelfAttention.o.weight', 'decoder.block.5.layer.1.EncDecAttention.o.weight', 'decoder.block.20.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.10.layer.1.EncDecAttention.v.weight', 'decoder.block.7.layer.0.SelfAttention.v.weight', 'decoder.block.8.layer.1.EncDecAttention.k.weight', 'decoder.block.6.layer.0.SelfAttention.v.weight', 'decoder.block.13.layer.2.layer_norm.weight', 'decoder.block.5.layer.1.EncDecAttention.v.weight', 'decoder.block.15.layer.1.EncDecAttention.o.w

T5EncoderModel(
  (shared): Embedding(144, 1536)
  (encoder): T5Stack(
    (embed_tokens): Embedding(144, 1536)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1536, out_features=1024, bias=False)
              (k): Linear(in_features=1536, out_features=1024, bias=False)
              (v): Linear(in_features=1536, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1536, bias=False)
              (relative_attention_bias): Embedding(64, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=1536, out_features=3840, bias=False)
              (wi_1): Linear(in_features=1536, out_features=3840, bias=False)
              (wo): Lin

In [28]:
def embed_batch_Ankh(seq_ls, model, tokenizer, device, shift_left=0, shift_right=-1):
    """Embed protein sequences one at a time with an Ankh encoder.

    Each sequence is split into single residues and tokenized with
    is_split_into_words=True and add_special_tokens=True. The encoder's
    last_hidden_state for the single batch item is converted to numpy and sliced
    with [shift_left:shift_right]. With the defaults this drops the final token,
    presumably the special token the tokenizer appends (TODO confirm).

    Returns a list with one numpy array per input sequence, in input order.
    """
    residue_lists = [list(sequence) for sequence in seq_ls]
    per_seq_embeddings = []
    with torch.no_grad():
        for residues in tqdm(residue_lists):
            encoded = tokenizer.batch_encode_plus(
                [residues],
                add_special_tokens=True,
                padding=True,
                is_split_into_words=True,
                return_tensors="pt",
            )
            token_ids = encoded['input_ids'].to(device=device)
            hidden = model(input_ids=token_ids).last_hidden_state
            token_states = hidden[0].detach().cpu().numpy()
            per_seq_embeddings.append(token_states[shift_left:shift_right])
    return per_seq_embeddings


In [None]:
def embed_batch_protT5(seq_ls, model, tokenizer, device, shift_left=0, shift_right=-1):
    """Embed protein sequences one at a time with a ProtT5 encoder.

    Residues are space-separated (the text form fed to the T5 tokenizer here) and
    encoded with special tokens and "longest" padding. The encoder receives both
    input_ids and attention_mask. The last_hidden_state of the single batch item
    is converted to numpy and sliced with [shift_left:shift_right]. With the
    defaults this drops the final token, presumably the appended special token
    (TODO confirm).

    Returns a list with one numpy array per input sequence, in input order.
    """
    spaced_batches = [[" ".join(sequence)] for sequence in seq_ls]
    results = []
    with torch.no_grad():
        for single_batch in tqdm(spaced_batches):
            encoded = tokenizer.batch_encode_plus(
                single_batch, add_special_tokens=True, padding="longest")
            input_ids = torch.tensor(encoded['input_ids']).to(device)
            attention_mask = torch.tensor(encoded['attention_mask']).to(device)

            output = model(input_ids=input_ids, attention_mask=attention_mask)
            token_states = output.last_hidden_state[0].detach().cpu().numpy()
            results.append(token_states[shift_left:shift_right])
    return results


### truncate, embed & save

In [34]:
# Perform truncation, embedding, and save for every dataset listed in the config.
# NOTE(review): directory values from embed.yaml are concatenated directly with
# file names below, so they presumably need a trailing path separator -- confirm.
if pLM not in ("Ankh", "ProtT5"):
    # fail before any files are written rather than midway through the loop
    raise ValueError(f"Unsupported pLM {pLM!r}; expected 'Ankh' or 'ProtT5'")

for file_name in data:
    acc_ls, seq_ls, label_ls = prep_data(
        f"{data_dir}{file_name}.fasta", f"{label_dir}{file_name}.npz",
        max_len=max_len)  # honour config['truncate'] instead of the 512 default
    np.savez(f'{label_save_dir}truncated_{file_name}.npz',
             **dict(zip(acc_ls, label_ls)))

    # Both branches must assign `embeddings`; it is consumed by save_embeddings below.
    if pLM == "Ankh":
        embeddings = embed_batch_Ankh(seq_ls, model, tokenizer,
                                      device, shift_left=0, shift_right=-1)
    else:
        embeddings = embed_batch_protT5(seq_ls, model, tokenizer, device)

    # zip() silently truncates on a length mismatch; fail loudly instead.
    assert len(embeddings) == len(acc_ls), (
        f"{file_name}: {len(embeddings)} embeddings for {len(acc_ls)} accessions")
    save_embeddings(dict(zip(acc_ls, embeddings)),
                    f"{embed_save_dir}{file_name}_embeddings.h5")




100%|██████████| 3015/3015 [13:08<00:00,  3.83it/s]




100%|██████████| 3064/3064 [14:16<00:00,  3.58it/s]




100%|██████████| 3188/3188 [14:54<00:00,  3.57it/s]




100%|██████████| 2986/2986 [13:37<00:00,  3.65it/s]




100%|██████████| 3115/3115 [15:18<00:00,  3.39it/s]




100%|██████████| 3045/3045 [17:00<00:00,  2.98it/s]




100%|██████████| 3051/3051 [15:31<00:00,  3.28it/s]




100%|██████████| 3117/3117 [15:28<00:00,  3.36it/s]




100%|██████████| 3131/3131 [14:48<00:00,  3.52it/s]




100%|██████████| 3084/3084 [14:16<00:00,  3.60it/s]




100%|██████████| 3093/3093 [14:18<00:00,  3.60it/s]




100%|██████████| 3109/3109 [14:31<00:00,  3.57it/s]




100%|██████████| 3053/3053 [14:08<00:00,  3.60it/s]


### sanity check: reload saved embeddings

In [6]:
# Sanity check: reload the embeddings saved for the first dataset (prints each stored id).
embedding_ = retrive_embedding("{}{}_embeddings.h5".format(embed_save_dir, data[0]))


A0A0H3KB22
A0A1C7D1B7
A0NLY7
A0Q5Y3_0
A0Q5Y3_512
A0Q5Y3_1024
A0Q5Y3_1536
A0QZY0
A3DC27_0
A3DC27_512
A4XF23
A5TYT6_0
A5TYT6_512
A9CK16
B0T0B1
B3PDB1
B5HDJ6
C6D9S0
D0E8I5
D0EM77
D6WI29
E9Q555_0
E9Q555_512
E9Q555_1024
E9Q555_1536
E9Q555_2048
E9Q555_2560
E9Q555_3072
E9Q555_3584
E9Q555_4096
E9Q555_4608
E9Q555_5120
F8GV06
I0DF35_0
I0DF35_512
I0DF35_1024
I0DF35_1536
I0DF35_2048
I3LM39
I6XD65
K5B7F3
O05581
O08498
O13833
O14607_0
O14607_512
O14607_1024
O15550_0
O15550_512
O15550_1024
O15865_0
O15865_512
O31526_0
O31526_512
O31527_0
O31527_512
O34714
O43819
O53512
O55023
O66188
O66990
O67050
O75688
O75874
O83774
O87198
O92956_0
O92956_512
O92956_1024
O92956_1536
O94753
O95251_0
O95251_512
O95989
P00727_0
P00727_512
P00971
P03265_0
P03265_512
P03354_0
P03354_512
P03354_1024
P03354_1536
P03956
P04129
P05186_0
P05186_512
P05187_0
P05187_512
P05806_0
P05806_512
P06786_0
P06786_512
P06786_1024
P06988
P06996
P07379_0
P07379_512
P07598
P07846
P07884
P07902
P08200
P08254
P09148
P09958_0
P09958_512
P0A8M