https://www.phind.com/search?cache=841083d9-ba45-4824-924d-6a8f7bb780e5

In [14]:
import esm

In [15]:
!pip install "fair-esm[esmfold]"
# OpenFold and its remaining dependency
!pip install 'dllogger @ git+https://github.com/NVIDIA/dllogger.git'
#!pip install 'openfold @ git+https://github.com/aqlaboratory/openfold.git@4b41059694619831a7db195b7e0988fc4ff3a307'

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m
Defaulting to user installation because normal site-packages is not writeable
Collecting dllogger@ git+https://github.com/NVIDIA/dllogger.git
  Cloning https://github.com/NVIDIA/dllogger.git to /tmp/pip-install-6migieyi/dllogger_98b9a48b687f440eab7eb150663fdfc3
  Running command git clone --filter=blob:none --quiet https://github.com/NVIDIA/dllogger.git /tmp/pip-install-6migieyi/dllogger_98b9a48b687f440eab7eb150663fdfc3
  Resolved https://github.com/NVIDIA/dllogger.git to commit 0540a43971f4a8a16693a9de9de73c1072020769
  Preparing metadata (setup.py) ... [?25ldone

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available:

In [52]:
import torch
from torch.utils.data import Dataset, DataLoader

import numpy as np

from transformers import AutoModel, AutoTokenizer
from transformers import pipeline

# Load the ESM-2 language model and its tokenizer from the Hugging Face Hub.
model_name = 'facebook/esm2_t36_3B_UR50D'
esm_tokenizer = AutoTokenizer.from_pretrained(model_name)
# The ESM tokenizer already defines a dedicated padding token, so it is left
# untouched; re-pointing pad_token at eos_token (a GPT-style workaround) would
# make padded positions indistinguishable from the end-of-sequence token.
esm_model = AutoModel.from_pretrained(model_name)

# Load ESMFold in inference mode. Use the GPU when one is present and fall back
# to the CPU otherwise, instead of failing on machines without CUDA.
esmfold_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
esmfold_model = esm.pretrained.esmfold_v1()
esmfold_model = esmfold_model.eval().to(esmfold_device)

class ProteinDataset(Dataset):
    """Dataset pairing each protein sequence with its ESM and ESMFold outputs.

    Both models are run lazily inside ``__getitem__`` using the module-level
    ``esm_tokenizer``, ``esm_model`` and ``esmfold_model`` objects, so those
    must exist before an item is requested.

    Args:
        sequences: Indexable, sized collection of amino-acid sequence strings.
    """

    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Run ESM and ESMFold on the sequence stored at ``idx``.

        Returns:
            tuple: ``(esm_emb, esmfold_emb)``. ``esm_emb`` is the output object
            of ``esm_model`` for a batch of one sequence, so its tensors keep a
            leading batch axis of size 1. ``esmfold_emb`` is the value returned
            by ``esmfold_model.infer_pdb``.
            NOTE(review): in fair-esm ``infer_pdb`` appears to return
            PDB-formatted text rather than an embedding tensor - confirm
            before using it as a model input.
        """
        sequence = self.sequences[idx]

        # Every sequence is padded/truncated to the same fixed length so that
        # outputs of different items have identical shapes.
        expected_length = min(esm_model.config.max_position_embeddings, esm_tokenizer.model_max_length)

        # return_tensors="pt" yields batch-shaped integer tensors directly, which
        # replaces the former NumPy int32 round-trip and the manual unsqueeze.
        encoded = esm_tokenizer(
            sequence,
            padding="max_length",
            truncation=True,
            max_length=expected_length,
            return_tensors="pt",
        )

        # Inference only: no gradients are needed for either model.
        with torch.no_grad():
            esm_emb = esm_model(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
            )
            esmfold_emb = esmfold_model.infer_pdb(sequence)

        # Pair of ESM and ESMFold outputs for this sequence
        return (esm_emb, esmfold_emb)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of the model checkpoint at facebook/esm2_t36_3B_UR50D were not used when initializing EsmModel: ['lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.bias']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t36_3B_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [84]:
import os 
import tempfile
from zipfile import ZipFile

class GOEmbeddingsDataset(torch.utils.data.Dataset):
    """Map-style dataset over the ``(go_id, vector)`` pairs of a dictionary."""

    def __init__(self, go_id2vec_dict):
        # Materialise the items so they can be addressed by integer position.
        self.embeddings = [(key, value) for key, value in go_id2vec_dict.items()]

    def __len__(self):
        """Return the number of stored (id, vector) pairs."""
        return len(self.embeddings)

    def __getitem__(self, index):
        """Return ``(go_id, tensor)``; the tensor has an extra leading axis of size 1."""
        go_id, raw_vec = self.embeddings[index]
        # Copy into a fresh array, then prepend the singleton axis.
        row = np.expand_dims(np.array(raw_vec), axis=0)
        return go_id, torch.from_numpy(row)

def load_go_embeddings(
    path="/home/ubuntu/proteinbind/GO_pretrained/go_embeddings_20k.zip",
    subdir="go_embeddings",
):
    """Load GO-term embeddings from a zip archive of paired ``.npy`` files.

    Inside ``subdir`` of the archive, every ``.npy`` file whose name contains
    ``_ids`` is read as an array of GO identifiers; the file with the same name
    but ``_ids`` replaced by ``_vecs`` is read as the matching array of vectors.

    Args:
        path: Location of the zip archive. Defaults to the original hard-coded path.
        subdir: Directory inside the archive that holds the ``.npy`` files.

    Returns:
        dict: Maps each GO identifier to its vector as a plain Python list.
    """
    go_id2vec_dict = {}

    # Extract into a private directory that is removed afterwards, so stale
    # files from an earlier extraction can never be picked up by mistake.
    with tempfile.TemporaryDirectory() as extract_root:
        # The context manager guarantees the archive handle is closed.
        with ZipFile(path, "r") as zip_file:
            zip_file.extractall(extract_root)

        emb_dir = os.path.join(extract_root, subdir)

        # Only the id files drive the loop. Iterating over every .npy file would
        # also visit the `_vecs` files and load them as both ids and vectors.
        # Sorting makes the insertion order of the result deterministic.
        id_files = sorted(
            f for f in os.listdir(emb_dir) if f.endswith('.npy') and '_ids' in f
        )

        for id_file in id_files:
            go_ids = np.load(os.path.join(emb_dir, id_file))
            vecs = np.load(os.path.join(emb_dir, id_file.replace('_ids', '_vecs')))
            for go_id, vec in zip(go_ids, vecs):
                go_id2vec_dict[go_id] = vec.tolist()

    return go_id2vec_dict

# Build the GO-id -> vector mapping from the archive
go_id2vec_dict = load_go_embeddings()

# Wrap the mapping in a dataset and serve it in shuffled mini-batches of four
go_dataset = GOEmbeddingsDataset(go_id2vec_dict)
go_loader = DataLoader(go_dataset, batch_size=4, shuffle=True)

# Smoke test: pull a single batch, then stop.
#   go_ids  - batch of GO IDs
#   go_vecs - batch of the corresponding GO embeddings
for go_ids, go_vecs in go_loader:
    break

In [59]:
# Reference pages noted for the example sequences below:
# https://www.proteomicsdb.org/protein/56464/summary , https://www.proteomicsdb.org/protein/51836/summary
protein_sequence = ["MPGIVELPTLEELKVDEVKISSAVLKAAAHHYGAQCDKPNKEFMLCRWEEKDPRRCLEEGKLVNKCALDFFRQIKRHCAEPFTEYWTCIDYTGQQLFRHCRKQQAKFDECVLDKLGWVRPDLGELSKVTKVKTDRPLPENPYHSRPRPDPSPEIEGDLQPATHGSRFYFWTK",
                    "MTAKMETTFYDDALNASFLPSESGPYGYSNPKILKQSMTLNLADPVGSLKPHLRAKNSDLLTSPDVGLLKLASPELERLIIQSSNGHITTTPTPTQFLCPKNVTDEQEGFAEGFVRALAELHSQNTLPSVTSAAQPVNGAGMVAPAVASVAGGSGSGGFSASLHSEPPVYANLSNFNPGALSSGGGAPSYGAAGLAFPAQPQQQQQPPHHLPQQMPVQHPRLQALKEEPQTVPEMPGETPPLSPIDMESQERIKAERKRMRNRIAASKCRKRKLERIARLEEKVKTLKAQNSELASTANMLREQVAQLKQKVMNHVNSGCQLMLTQQLQTF"
                   ]

dataset = ProteinDataset(protein_sequence)

# Create data loader
loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Iterate through data loader
for esm_embs, esmfold_embs in loader:
    # esm_embs: batched ESM model outputs
    # esmfold_embs: batched ESMFold (`infer_pdb`) outputs, one entry per sample.
    # The batch is deliberately NOT unpacked into two names: doing so only
    # worked because batch_size happened to be 2, and mislabelled the samples.

    # Size of the ESM hidden states for this batch
    last_hidden_state = esm_embs.last_hidden_state
    esm_sz = last_hidden_state.size()
    print(f"ESM embedding size: {esm_sz}")

    # Print the ESM outputs and the ESMFold outputs of the whole batch
    print(f"esm_embs, {esm_embs}")
    print(f"esmfold_emb, {esmfold_embs}")

    # Do training using combined ESM and ESMFold embeddings

ESM embedding size: torch.Size([2, 1, 1026, 2560])
esm_embs, BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[[-0.0184,  0.0552, -0.0875,  ...,  0.0951, -0.0404, -0.1598],
          [-0.1159, -0.0938,  0.1529,  ...,  0.0591,  0.0315, -0.2219],
          [ 0.1499,  0.0846, -0.1652,  ...,  0.0591, -0.2045,  0.0263],
          ...,
          [-0.1100,  0.0565,  0.0057,  ...,  0.0742, -0.0928, -0.1681],
          [-0.0486,  0.0130,  0.0466,  ...,  0.0539, -0.1000, -0.2075],
          [-0.0050, -0.0296,  0.1553,  ...,  0.0214, -0.1564, -0.2101]]],


        [[[-0.0648,  0.0021, -0.0788,  ...,  0.1190, -0.0258, -0.0679],
          [-0.0809, -0.1543, -0.0164,  ...,  0.0897, -0.1649, -0.1291],
          [-0.0204,  0.0103, -0.0496,  ...,  0.0484, -0.2328, -0.0971],
          ...,
          [-0.0315,  0.0047, -0.0794,  ...,  0.1127,  0.0082, -0.0885],
          [-0.0009,  0.0168, -0.0818,  ...,  0.0992,  0.0012, -0.0693],
          [ 0.0741,  0.0683, -0.0774,  ...,  0.048