In [1]:
#@title Install requirements. { display-mode: "form" }
# Install requirements
!pip install torch transformers sentencepiece h5py



In [2]:
#@title Set up working directories and download files/checkpoints. { display-mode: "form" }
# Create directory for storing model weights (2.3GB) and example sequences.
# Here we use the encoder-part of ProtT5-XL-U50 in half-precision (fp16) as 
# it performed best in our benchmarks (also outperforming ProtBERT-BFD).
# Also download secondary structure prediction checkpoint to show annotation extraction from embeddings
!mkdir protT5 # root directory for storing checkpoints, results etc
!mkdir protT5/protT5_checkpoint # directory holding the ProtT5 checkpoint
!mkdir protT5/sec_struct_checkpoint # directory storing the supervised classifier's checkpoint
!mkdir protT5/output # directory for storing your embeddings & predictions
!wget -nc -P protT5/ https://rostlab.org/~deepppi/example_seqs.fasta
# Huge kudos to the bio_embeddings team here! We will integrate the new encoder, half-prec ProtT5 checkpoint soon
!wget -nc -P protT5/sec_struct_checkpoint http://data.bioembeddings.com/public/embeddings/feature_models/t5/secstruct_checkpoint.pt

mkdir: cannot create directory ‘protT5’: File exists
mkdir: cannot create directory ‘protT5/protT5_checkpoint’: File exists
mkdir: cannot create directory ‘protT5/sec_struct_checkpoint’: File exists
mkdir: cannot create directory ‘protT5/output’: File exists
File ‘protT5/example_seqs.fasta’ already there; not retrieving.

File ‘protT5/sec_struct_checkpoint/secstruct_checkpoint.pt’ already there; not retrieving.



In [3]:
# In the following you can define your desired output. Current options:
# per_residue embeddings
# per_protein embeddings
# secondary structure predictions

# Replace this file with your own (multi-)FASTA
# Headers are expected to start with ">";
seq_path = "./protT5/example_seqs.fasta"

# whether to retrieve embeddings for each residue in a protein 
# --> Lx1024 matrix per protein with L being the protein's length
# as a rule of thumb: 1k proteins require around 1GB RAM/disk
per_residue = True 
per_residue_path = "./protT5/output/per_residue_embeddings.h5" # where to store the embeddings

# whether to retrieve per-protein embeddings 
# --> only one 1024-d vector per protein, irrespective of its length
per_protein = False
per_protein_path = "./protT5/output/per_protein_embeddings.h5" # where to store the embeddings

# whether to retrieve secondary structure predictions
# This can be replaced by your method after being trained on ProtT5 embeddings
sec_struct = False
sec_struct_path = "./protT5/output/ss3_preds.fasta" # file for storing predictions

# make sure that either per-residue or per-protein embeddings are stored
assert per_protein is True or per_residue is True or sec_struct is True, print(
    "Minimally, you need to active per_residue, per_protein or sec_struct. (or any combination)")


In [4]:
#@title Import dependencies and check whether GPU is available. { display-mode: "form" }
from transformers import T5EncoderModel, T5Tokenizer
import torch
import h5py
import time
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Using {}".format(device))

  from .autonotebook import tqdm as notebook_tqdm


Using cuda:0


In [5]:
#@title Network architecture for secondary structure prediction. { display-mode: "form" }
# Convolutional neural network (two convolutional layers) to predict secondary structure
class ConvNet( torch.nn.Module ):
    def __init__( self ):
        super(ConvNet, self).__init__()
        # This is only called "elmo_feature_extractor" for historic reason
        # CNN weights are trained on ProtT5 embeddings
        self.elmo_feature_extractor = torch.nn.Sequential(
                        torch.nn.Conv2d( 1024, 32, kernel_size=(7,1), padding=(3,0) ), # 7x32
                        torch.nn.ReLU(),
                        torch.nn.Dropout( 0.25 ),
                        )
        n_final_in = 32
        self.dssp3_classifier = torch.nn.Sequential(
                        torch.nn.Conv2d( n_final_in, 3, kernel_size=(7,1), padding=(3,0)) # 7
                        )
        
        self.dssp8_classifier = torch.nn.Sequential(
                        torch.nn.Conv2d( n_final_in, 8, kernel_size=(7,1), padding=(3,0))
                        )
        self.diso_classifier = torch.nn.Sequential(
                        torch.nn.Conv2d( n_final_in, 2, kernel_size=(7,1), padding=(3,0))
                        )
        

    def forward( self, x):
        # IN: X = (B x L x F); OUT: (B x F x L, 1)
        x = x.permute(0,2,1).unsqueeze(dim=-1) 
        x         = self.elmo_feature_extractor(x) # OUT: (B x 32 x L x 1)
        d3_Yhat   = self.dssp3_classifier( x ).squeeze(dim=-1).permute(0,2,1) # OUT: (B x L x 3)
        d8_Yhat   = self.dssp8_classifier( x ).squeeze(dim=-1).permute(0,2,1) # OUT: (B x L x 8)
        diso_Yhat = self.diso_classifier(  x ).squeeze(dim=-1).permute(0,2,1) # OUT: (B x L x 2)
        return d3_Yhat, d8_Yhat, diso_Yhat

In [6]:
#@title Load the checkpoint for secondary structure prediction. { display-mode: "form" }
def load_sec_struct_model():
  checkpoint_dir="./protT5/sec_struct_checkpoint/secstruct_checkpoint.pt"
  state = torch.load( checkpoint_dir )
  model = ConvNet()
  model.load_state_dict(state['state_dict'])
  model = model.eval()
  model = model.to(device)
  print('Loaded sec. struct. model from epoch: {:.1f}'.format(state['epoch']))

  return model

In [7]:
#@title Load encoder-part of ProtT5 in half-precision. { display-mode: "form" }
# Load ProtT5 in half-precision (more specifically: the encoder-part of ProtT5-XL-U50) 
def get_T5_model():
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
    model = model.to(device) # move model to GPU
    model = model.eval() # set model to evaluation model
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)

    return model, tokenizer

In [8]:
#@title Read in file in fasta format. { display-mode: "form" }
def read_fasta( fasta_path, split_char="!", id_field=0):
    '''
        Reads in fasta file containing multiple sequences.
        Split_char and id_field allow to control identifier extraction from header.
        E.g.: set split_char="|" and id_field=1 for SwissProt/UniProt Headers.
        Returns dictionary holding multiple sequences or only single 
        sequence, depending on input file.
    '''
    
    seqs = dict()
    with open( fasta_path, 'r' ) as fasta_f:
        for line in fasta_f:
            # get uniprot ID from header and create new entry
            if line.startswith('>'):
                uniprot_id = line.replace('>', '').strip().split(split_char)[id_field]
                # replace tokens that are mis-interpreted when loading h5
                uniprot_id = uniprot_id.replace("/","_").replace(".","_")
                seqs[ uniprot_id ] = ''
            else:
                # repl. all whie-space chars and join seqs spanning multiple lines, drop gaps and cast to upper-case
                seq= ''.join( line.split() ).upper().replace("-","")
                # repl. all non-standard AAs and map them to unknown/X
                seq = seq.replace('U','X').replace('Z','X').replace('O','X')
                seqs[ uniprot_id ] += seq 
    example_id=next(iter(seqs))
    print("Read {} sequences.".format(len(seqs)))
    print("Example:\n{}\n{}".format(example_id,seqs[example_id]))

    return seqs

In [9]:
#@title Generate embeddings. { display-mode: "form" }
# Generate embeddings via batch-processing
# per_residue indicates that embeddings for each residue in a protein should be returned.
# per_protein indicates that embeddings for a whole protein should be returned (average-pooling)
# max_residues gives the upper limit of residues within one batch
# max_seq_len gives the upper sequences length for applying batch-processing
# max_batch gives the upper number of sequences per batch
def get_embeddings( model, tokenizer, seqs, per_residue, per_protein, sec_struct, 
                   max_residues=4000, max_seq_len=1000, max_batch=100 ):

    if sec_struct:
      sec_struct_model = load_sec_struct_model()

    results = {"residue_embs" : dict(), 
               "protein_embs" : dict(),
               "sec_structs" : dict() 
               }

    # sort sequences according to length (reduces unnecessary padding --> speeds up embedding)
    # seq_dict   = sorted( seqs.items(), key=lambda kv: len( seqs[kv[0]] ), reverse=True )
    seq_dict = seqs.items()

    start = time.time()
    batch = list()
    for seq_idx, (pdb_id, seq) in enumerate(seq_dict,1):
        seq = seq
        seq_len = len(seq)
        seq = ' '.join(list(seq))
        batch.append((pdb_id,seq,seq_len))

        # count residues in current batch and add the last sequence length to
        # avoid that batches with (n_res_batch > max_residues) get processed 
        n_res_batch = sum([ s_len for  _, _, s_len in batch ]) + seq_len 
        if len(batch) >= max_batch or n_res_batch>=max_residues or seq_idx==len(seq_dict) or seq_len>max_seq_len:
            pdb_ids, seqs, seq_lens = zip(*batch)
            print("pdb seqs len",pdb_ids, seqs, seq_lens)
            batch = list()

            # add_special_tokens adds extra token at the end of each sequence
            token_encoding = tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding="longest")
            input_ids      = torch.tensor(token_encoding['input_ids']).to(device)
            attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)
            
            try:
                with torch.no_grad():
                    # returns: ( batch-size x max_seq_len_in_minibatch x embedding_dim )
                    embedding_repr = model(input_ids, attention_mask=attention_mask)
            except RuntimeError:
                print("RuntimeError during embedding for {} (L={})".format(pdb_id, seq_len))
                continue

            if sec_struct: # in case you want to predict secondary structure from embeddings
              d3_Yhat, d8_Yhat, diso_Yhat = sec_struct_model(embedding_repr.last_hidden_state)


            for batch_idx, identifier in enumerate(pdb_ids): # for each protein in the current mini-batch
                s_len = seq_lens[batch_idx]
                # slice off padding --> batch-size x seq_len x embedding_dim  
                emb = embedding_repr.last_hidden_state[batch_idx,:s_len]
                if sec_struct: # get classification results
                    results["sec_structs"][identifier] = torch.max( d3_Yhat[batch_idx,:s_len], dim=1 )[1].detach().cpu().numpy().squeeze()
                if per_residue: # store per-residue embeddings (Lx1024)
                    results["residue_embs"][ identifier ] = emb.detach().cpu().numpy().squeeze()
                if per_protein: # apply average-pooling to derive per-protein embeddings (1024-d)
                    protein_emb = emb.mean(dim=0)
                    results["protein_embs"][identifier] = protein_emb.detach().cpu().numpy().squeeze()


    passed_time=time.time()-start
    avg_time = passed_time/len(results["residue_embs"]) if per_residue else passed_time/len(results["protein_embs"])
    print('\n############# EMBEDDING STATS #############')
    print('Total number of per-residue embeddings: {}'.format(len(results["residue_embs"])))
    print('Total number of per-protein embeddings: {}'.format(len(results["protein_embs"])))
    print("Time for generating embeddings: {:.1f}[m] ({:.3f}[s/protein])".format(
        passed_time/60, avg_time ))
    print('\n############# END #############')
    return results

In [10]:
# #@title Write embeddings to disk. { display-mode: "form" }
# def save_embeddings(emb_dict,out_path):
#     with h5py.File(str(out_path), "w") as hf:
#         for sequence_id, embedding in emb_dict.items():
#             # noinspection PyUnboundLocalVariable
#             # hf.create_dataset(sequence_id, data=embedding)
#             # Create a dataset with the same shape as the "embedding" variable
#             dset = hf.create_dataset(sequence_id, shape=embedding.shape, dtype='float32')
#             # Assign the "embedding" variable to the dataset
#             dset[:] = embedding
#     return None

In [11]:
import h5py

def save_embeddings(emb_dict, out_path):
    with h5py.File(str(out_path), "w") as hf:
        group = hf.create_group("embeddings")
        for i, (sequence_id, embedding) in enumerate(emb_dict.items()):
            # Create a new dataset within the group, using a numerical index as the name of the dataset
            dset = group.create_dataset(str(i), data=embedding)
            # Create a new attribute for the dataset with the original sequence id name
            dset.attrs['sequence_id'] = sequence_id
    return None







In [12]:
#@title Write predictions to disk. { display-mode: "form" }
def write_prediction_fasta(predictions, out_path):
  class_mapping = {0:"H",1:"E",2:"L"} 
  with open(out_path, 'w+') as out_f:
      out_f.write( '\n'.join( 
          [ ">{}\n{}".format( 
              seq_id, ''.join( [class_mapping[j] for j in yhat] )) 
          for seq_id, yhat in predictions.items()
          ] 
            ) )
  return None

In [13]:
# Load the encoder part of ProtT5-XL-U50 in half-precision (recommended)
model, tokenizer = get_T5_model()

In [14]:
from read_fasta import *
import sys
import os
sys.path.append('/home/pasang/all_experiment/FvFold')
import fvfold
project_path = os.path.abspath(os.path.join(fvfold.__file__, "../.."))
path = os.path.join(project_path, "data/")
filename="antibody.h5"
fasta_dir=os.path.join(path, "antibody_database/")
all_seq=read_all_seq(path,filename,fasta_dir)
all_seq

/home/pasang/all_experiment/FvFold/data/antibody.h5
Keys: <KeysViewHDF5 ['h1_range', 'h2_range', 'h3_range', 'heavy_chain_primary', 'heavy_chain_seq_len', 'id', 'l1_range', 'l2_range', 'l3_range', 'light_chain_primary', 'light_chain_seq_len', 'pairwise_geometry_mat']>
Read 2 sequences.
Example:
1baf:H
DVQLQESGPGLVKPSQSQSLTCTVTGYSITSDYAWNWIRQFPGNKLEWMGYMSYSGSTRYNPSLRSRISITRDTSKNQFFLQLKSVTTEDTATYFCARGWPLAYWGQGTQVSVS
Read 2 sequences.
Example:
1a6t:H
EVQLQQSGPDLVKPGASVKISCKASGYSFSTYYMHWVKQSHGKSLEWIGRVDPDNGGTSFNQKFKGKAILTVDKSSSTAYMELGSLTSEDSAVYYCARRDDYYFDFWGQGTSLTVS
Read 2 sequences.
Example:
1ad9:H
EIQLVQSGAEVKKPGSSVKVSCKASGYTFTDYYINWMRQAPGQGLEWIGWIDPGSGNTKYNEKFKGRATLTVDTSTNTAYMELSSLRSEDTAFYFCAREKTTYYYAMDYWGQGTLVTVS
Read 2 sequences.
Example:
1a3r:H
EVQLQQSGAELVRPGASVKLSCTTSGFNIKDIYIHWVKQRPEQGLEWIGRLDPANGYTKYDPKFQGKATITVDTSSNTAYLHLSSLTSEDTAVYYCDGYYSYYDMDYWGPGTSVTVS
Read 2 sequences.
Example:
1afv:H
QVQLQQPGSVLVRPGASVKLSCKASGYTFTSSWIHWAKQRPGQGLEWIGEIHPNSGNTNYNEKFKGKATLTVDTSSSTAYVDLSSLTSEDS

{'1baf:H': 'DVQLQESGPGLVKPSQSQSLTCTVTGYSITSDYAWNWIRQFPGNKLEWMGYMSYSGSTRYNPSLRSRISITRDTSKNQFFLQLKSVTTEDTATYFCARGWPLAYWGQGTQVSVS',
 '1baf:L': 'QIVLTQSPAIMSASPGEKVTMTCSASSSVYYMYWYQQKPGSSPRLLIYDTSNLASGVPVRFSGSGSGTSYSLTISRMEAEDAATYYCQQWSSYPPITFGVGTKLELKRA',
 '1a6t:H': 'EVQLQQSGPDLVKPGASVKISCKASGYSFSTYYMHWVKQSHGKSLEWIGRVDPDNGGTSFNQKFKGKAILTVDKSSSTAYMELGSLTSEDSAVYYCARRDDYYFDFWGQGTSLTVS',
 '1a6t:L': 'QSVLSQSPAILSASPGEKVIMTCSPSSSVSYMQWYQQKPGSSPKPWIYSTSNLASGVPGRFSGGGSGTSFSLTISGVEAEDAATYYCQQYSSHPLTFGGGTKLELKRA',
 '1ad9:H': 'EIQLVQSGAEVKKPGSSVKVSCKASGYTFTDYYINWMRQAPGQGLEWIGWIDPGSGNTKYNEKFKGRATLTVDTSTNTAYMELSSLRSEDTAFYFCAREKTTYYYAMDYWGQGTLVTVS',
 '1ad9:L': 'DIQMTQSPSTLSASVGDRVTITCRSSKSLLHSNGDTFLYWFQQKPGKAPKLLMYRMSNLASGVPSRFSGSGSGTEFTLTISSLQPDDFATYYCMQHLEYPFTFGQGTKVEVKRT',
 '1a3r:H': 'EVQLQQSGAELVRPGASVKLSCTTSGFNIKDIYIHWVKQRPEQGLEWIGRLDPANGYTKYDPKFQGKATITVDTSSNTAYLHLSSLTSEDTAVYYCDGYYSYYDMDYWGPGTSVTVS',
 '1a3r:L': 'DIVMTQSPSSLTVTTGEKVTMTCKSSQSLLNSRTQKNYLTWYQQKPGQSPKLLIYWASTRESGVPDRFTGSGSGTDFTLSISGVQA

In [15]:


# Load example fasta.
# seqs = read_fasta( seq_path )

# Compute embeddings and/or secondary structure predictions
results = get_embeddings( model, tokenizer, all_seq,
                         per_residue, per_protein, sec_struct)
per_residue_path =os.path.join(path, "per_residue_embeddings.h5")
# Store per-residue embeddings
if per_residue:
  save_embeddings(results["residue_embs"], per_residue_path)
# if per_protein:
#   save_embeddings(results["protein_embs"], per_protein_path)
# if sec_struct:
#   write_prediction_fasta(results["sec_structs"], sec_struct_path)

pdb seqs len ('1baf:H', '1baf:L', '1a6t:H', '1a6t:L', '1ad9:H', '1ad9:L', '1a3r:H', '1a3r:L', '1afv:H', '1afv:L', '1a5f:H', '1a5f:L', '1adq:H', '1adq:L', '1ay1:H', '1ay1:L', '1axs:H', '1axs:L', '1a7p:H', '1a7p:L', '1ae6:H', '1ae6:L', '1ahw:H', '1ahw:L', '1b4j:H', '1b4j:L', '1a7q:H', '1a7q:L', '1a6v:H', '1a6v:L', '1aj7:H', '1aj7:L', '1aqk:H', '1aqk:L', '1b2w:H') ('D V Q L Q E S G P G L V K P S Q S Q S L T C T V T G Y S I T S D Y A W N W I R Q F P G N K L E W M G Y M S Y S G S T R Y N P S L R S R I S I T R D T S K N Q F F L Q L K S V T T E D T A T Y F C A R G W P L A Y W G Q G T Q V S V S', 'Q I V L T Q S P A I M S A S P G E K V T M T C S A S S S V Y Y M Y W Y Q Q K P G S S P R L L I Y D T S N L A S G V P V R F S G S G S G T S Y S L T I S R M E A E D A A T Y Y C Q Q W S S Y P P I T F G V G T K L E L K R A', 'E V Q L Q Q S G P D L V K P G A S V K I S C K A S G Y S F S T Y Y M H W V K Q S H G K S L E W I G R V D P D N G G T S F N Q K F K G K A I L T V D K S S S T A Y M E L G S L T S E D S 

In [16]:
def read_embeddings(file_path):
    emb_dict = {}
    with h5py.File(file_path, 'r') as hf:
        group = hf['embeddings']
        for key in sorted(group.keys(), key=lambda x: int(x)):
            emb_dict[group[key].attrs['sequence_id']] = group[key][:]
    return emb_dict
emb_dict = read_embeddings(os.path.join(path, "per_residue_embeddings.h5"))
emb_dict.keys()

dict_keys(['1baf:H', '1baf:L', '1a6t:H', '1a6t:L', '1ad9:H', '1ad9:L', '1a3r:H', '1a3r:L', '1afv:H', '1afv:L', '1a5f:H', '1a5f:L', '1adq:H', '1adq:L', '1ay1:H', '1ay1:L', '1axs:H', '1axs:L', '1a7p:H', '1a7p:L', '1ae6:H', '1ae6:L', '1ahw:H', '1ahw:L', '1b4j:H', '1b4j:L', '1a7q:H', '1a7q:L', '1a6v:H', '1a6v:L', '1aj7:H', '1aj7:L', '1aqk:H', '1aqk:L', '1b2w:H', '1b2w:L', '1ai1:H', '1ai1:L', '1bfo:H', '1bfo:L', '1ad0:H', '1ad0:L', '1a4j:H', '1a4j:L'])