In [29]:
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from esm import pretrained
from proteinclip import model_utils
!export ORT_DISABLE_THREAD_AFFINITY=1

In [11]:
df = pd.read_parquet('../protclip_dataset.parquet')

In [17]:
# Load ESM2-33 model
esm_model, alphabet = pretrained.esm2_t33_650M_UR50D()
esm_model.eval()
batch_converter = alphabet.get_batch_converter()

# Load ProteinCLIP projection head for ESM2-33
proteinclip_model = model_utils.load_proteinclip("esm", 33)


[1;31m2025-04-14 23:34:00.707335831 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 319373, index: 20, mask: {21, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-14 23:34:00.716595987 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 319414, index: 60, mask: {61, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-14 23:34:00.707360281 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 319364, index: 11, mask: {12, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-14 23:34:00.723526357 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 319381, index: 28, mask: {29, }, error code: 22 e

In [26]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
esm_model.to(device)

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [27]:
def embed_batch_esm2(seqs, repr_layer=33, device="cuda"):
    batch_data = [(f"seq{i}", seq) for i, seq in enumerate(seqs)]
    _, _, batch_tokens = batch_converter(batch_data)
    batch_tokens = batch_tokens.to(device)

    with torch.no_grad():
        results = esm_model(batch_tokens, repr_layers=[repr_layer], return_contacts=False)
    
    embeddings = []
    for i in range(len(seqs)):
        # Extract per-token representations and average (skip BOS and EOS)
        token_reps = results["representations"][repr_layer][i, 1:-1]
        avg_emb = token_reps.mean(0).cpu().numpy()
        embeddings.append(avg_emb)

    return embeddings


In [28]:
def compute_proteinclip_embeddings(df, col="sequence", batch_size=32, device="cuda"):
    clip_embeds = []

    for i in tqdm(range(0, len(df), batch_size)):
        batch_seqs = df[col].iloc[i : i + batch_size].tolist()

        # Step 1: ESM2 embeddings
        esm_embs = embed_batch_esm2(batch_seqs, device=device)

        # Step 2: Normalize + project with ProteinCLIP
        for emb in esm_embs:
            emb = emb / np.linalg.norm(emb)
            clip = proteinclip_model.predict(emb)
            clip_embeds.append(clip)

    return np.stack(clip_embeds)  # shape: (N, 128)


In [35]:
clip_embeddings = compute_proteinclip_embeddings(df, col="sequence", batch_size=256, device="cuda")
print(clip_embeddings.shape)  # (num_sequences, 128)

100%|██████████| 40/40 [07:57<00:00, 11.94s/it]

(10000, 128)





In [38]:
np.array(clip_embeddings).shape

(10000, 128)

In [39]:
len(df)

10000

In [40]:
df['proteinclip_embed'] = np.array(clip_embeddings).tolist()

In [42]:
df.to_parquet('protclip_embed_dataset.parquet')

In [43]:
df.head()

Unnamed: 0,organism,organism_id,name,evidence,function,id,embedding,sequence,proteinclip_embed
0,Homo sapiens (Human),9606,EPHA7,1,Receptor tyrosine kinase which binds promiscuo...,Q15375,"[-0.015253728, 0.016237658, -0.016555615, 0.02...",MVFQTRYPSWIILCYIWLLRFAHTGEAQAAKEVLLLDSKAQQTELE...,"[0.16681893169879913, 0.105362169444561, 0.011..."
1,Homo sapiens (Human),9606,ANXA8,1,This protein is an anticoagulant protein that ...,P13928,"[-0.008352073, 0.00474287, 0.006541474, -0.002...",MAWWKSWIEQEGVTVKSSSHFNPDPDAETLYKAMKGIGTNEQAIID...,"[-0.10789338499307632, 0.06770769506692886, -0..."
2,Homo sapiens (Human),9606,DPY19L2P1,2,Probable C-mannosyltransferase that mediates C...,Q6NXN4,"[-0.00039709447, -0.02393247, -0.014100584, 0....",MKKQGVNPKPLQSSRPSPSKRPYGASPARELEVEKSALGGGKLPGG...,"[0.026129087433218956, 0.1801730990409851, 0.0..."
3,Homo sapiens (Human),9606,NR1D1,1,Transcriptional repressor which coordinates ci...,P20393,"[0.008172105, -0.0116752, -0.016805198, -0.005...",MTTLDSNNNTGGVITYIGSSGSSPSRTSPESLYSDNSNGSFQSLTQ...,"[-0.016533901914954185, -0.0384892001748085, 0..."
4,Homo sapiens (Human),9606,SLC15A2,1,Proton-coupled amino-acid transporter that tra...,Q16348,"[0.0031013805, -0.0019497981, -0.003831747, 0....",MNPFQKNESKETLFSPVSIEEVPPRPPSPPKKPSPTICGSNYPLSI...,"[0.08691893517971039, 0.07942728698253632, -0...."
