In [1]:
!pip install duckdb --no-index --find-links=/kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/duck_pkg
!pip install polars --no-index --find-links=/kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/polars_pkg
!pip install biopython
#!pip install fair-esm --no-index --find-links=/kaggle/input/suman-fair-esm/kaggle/working/fair_esm-2.0.0-py3-none-any.whl

Looking in links: /kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/duck_pkg
Looking in links: /kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/polars_pkg
Collecting biopython
  Downloading biopython-1.86-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)
Downloading biopython-1.86-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m60.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.86


In [2]:
import torch

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import random
import pprint
import os
import duckdb as dd
import polars as pl
from Bio import SeqIO

import transformers

In [4]:
## Run this twice, first time, it will show error
from transformers import AutoTokenizer, EsmModel

In [5]:
model_to_use = 'facebook/esm2_t6_8M_UR50D'
fasta_file_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta'
output_embed_dir = "/kaggle/working/esm_embeddings"

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_to_use)
model = EsmModel.from_pretrained(model_to_use)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
model.to(device)
model.eval()

In [None]:
os.makedirs(output_dir, exist_ok=True)

In [None]:
sequences = SeqIO.parse(fasta_file_path, "fasta")

In [None]:
train_sequences_list = list(sequences)
len(train_sequences_list)

In [None]:
last_record = train_sequences_list[-1]  # using Python's list tricks
print(last_record.id)
print(repr(last_record.seq))
print(len(last_record))

In [None]:
len(train_sequences_list[1001:1801])

In [None]:
for record in train_sequences_list[1001:1801]:
    seq_id = record.id
    sequence = str(record.seq)

    # Tokenize the sequence
    # The tokenizer adds special tokens (CLS and SEP) automatically
    inputs = tokenizer(sequence, return_tensors='pt', padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate embeddings
    with torch.no_grad():
        output = model(**inputs)

    # Extract embeddings
    # We usually take the last hidden states
    embeddings = output.last_hidden_state # Shape: (batch_size, sequence_length, hidden_size)

    # To get a single fixed-size embedding for the whole protein (mean pooling), 
    # you can average over the sequence length dimension (excluding CLS/SEP if desired, though often included for simplicity)
    # Here we average over all tokens in the sequence.
    # Alternatively, you can use the representation of the CLS token
    # sequence_level_embedding = output.pooler_output # This only works for certain model configurations
    
    # Mean pooling as an alternative to pooler_output
    # Start from index 1 and end at index -1 to ignore CLS and SEP tokens if they exist, 
    # but the standard practice with ESM is often to include all or use the representation from a specific layer
    # For simplicity, we can do mean pooling across the sequence dimension for now:
    sequence_embeddings = embeddings.mean(dim=1).squeeze().cpu().numpy() # Shape: (hidden_size,)

    # Save the embeddings
    output_path = os.path.join(output_dir, f"{seq_id}.pt")
    torch.save({'embeddings': sequence_embeddings}, output_path) # Save as a dictionary
    print(f"Saved embeddings for {seq_id} to {output_path}")

In [None]:
import gc
gc.collect()

In [None]:
torch.cuda.empty_cache()

In [None]:
del inputs

In [None]:
del output
del embeddings
del sequence_embeddings

In [7]:
from tqdm.notebook import tqdm # Use tqdm for a progress bar in Kaggle/Jupyter

def generate_embeddings_optimized(fasta_file, output_dir, model_name='facebook/esm2_t12_35M_UR50D', batch_size=4):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Load model in float16 for memory efficiency (Requires modern GPU/hardware)
    model = EsmModel.from_pretrained(model_name, torch_dtype=torch.float16) 
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    os.makedirs(output_dir, exist_ok=True)

    # Read all sequences into a list first to prepare for batching
    sequences = list(SeqIO.parse(fasta_file, "fasta"))
    print(f"Total sequences to process: {len(sequences)}")

    for i in tqdm(range(0, len(sequences), batch_size)):
        batch = sequences[i:i+batch_size]
        seq_ids = [record.id for record in batch]
        seqs = [str(record.seq) for record in batch]

        # Tokenize the batch
        inputs = tokenizer(seqs, return_tensors='pt', padding=True, truncation=True, max_length=1022)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Extract and save embeddings individually to manage memory
        for j, seq_id in enumerate(seq_ids):
            # Mean pooling for sequence-level embedding
            embedding = outputs.last_hidden_state[j].mean(dim=0).cpu().numpy()
            output_path = os.path.join(output_dir, f"{seq_id}.pt")
            torch.save({'embeddings': embedding}, output_path)

        # Explicitly clear input and output tensors to free VRAM immediately
        del inputs, outputs, batch
        torch.cuda.empty_cache()

    print("Embedding generation complete.")

In [8]:
# Example usage with a smaller, memory-efficient model:
generate_embeddings_optimized(fasta_file_path, output_embed_dir, model_to_use, 8)

tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/775 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/31.4M [00:00<?, ?B/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total sequences to process: 82404


  0%|          | 0/10301 [00:00<?, ?it/s]

Embedding generation complete.


In [9]:
from zipfile import ZipFile

dirName = '/kaggle/working'
zipName = 'esm2_t36_3B_UR50D_embeds_suman.zip'

# Create a ZipFile Object
with ZipFile(zipName, 'w') as zipObj:
    # Iterate over all the files in directory
    for folderName, subfolders, filenames in os.walk(dirName):
        for filename in filenames:
            if (filename != zipName):
                # create complete filepath of file in directory
                filePath = os.path.join(folderName, filename)
                # Add file to zip
                zipObj.write(filePath)

In [12]:
def embeddings_to_dataframe(directory_path):
    """Loads embeddings and organizes them into a Pandas DataFrame."""
    data_list = []
    for filename in tqdm(os.listdir(directory_path)):
        if filename.endswith(".pt"):
            file_path = os.path.join(directory_path, filename)
            data = torch.load(file_path, weights_only=False)
            seq_id = os.path.splitext(filename)[0]
            
            # The embedding data itself is in 'embeddings'
            embedding = data['embeddings'] 
            
            # Create a row for the DataFrame
            row = {'Sequence_ID': seq_id, 'Embedding': embedding}
            data_list.append(row)
            
    #df = pd.DataFrame(data_list)
    df = pl.DataFrame(data_list)
    print(f"DataFrame created with {len(df)} entries.")
    return df

In [13]:
# Example usage:
embeddings_df = embeddings_to_dataframe("/kaggle/working/esm_embeddings")
print(embeddings_df.head())

  0%|          | 0/82404 [00:00<?, ?it/s]

DataFrame created with 82404 entries.
shape: (5, 2)
┌──────────────────────────┬─────────────────────────────────┐
│ Sequence_ID              ┆ Embedding                       │
│ ---                      ┆ ---                             │
│ str                      ┆ object                          │
╞══════════════════════════╪═════════════════════════════════╡
│ sp|D3ZHV2|MACF1_RAT      ┆ [-1.1731e-01 -1.2988e-01  3.58… │
│ sp|Q6NLB0|GSTL1_ARATH    ┆ [-1.7563e-02 -1.1700e-01  1.25… │
│ sp|P82233|BR1A_RANTE     ┆ [-4.3030e-02 -1.8921e-02  2.26… │
│ sp|Q6P5F6|S39AA_MOUSE    ┆ [-9.6008e-02 -1.5625e-02  5.49… │
│ sp|A0A7E6FSU6|OXDD_OCTVU ┆ [-3.0200e-01 -1.5735e-01  1.91… │
└──────────────────────────┴─────────────────────────────────┘


In [15]:
def get_simple_accession(full_id):
    """Strips a UniProt ID like 'sp|D3ZHV2|MACF1_RAT' to 'D3ZHV2'."""
    parts = full_id.split('|')
    if len(parts) == 3:
        return parts[1]
    return full_id

In [21]:
print(embeddings_df.head())

shape: (5, 2)
┌──────────────────────────┬─────────────────────────────────┐
│ Sequence_ID              ┆ Embedding                       │
│ ---                      ┆ ---                             │
│ str                      ┆ object                          │
╞══════════════════════════╪═════════════════════════════════╡
│ sp|D3ZHV2|MACF1_RAT      ┆ [-1.1731e-01 -1.2988e-01  3.58… │
│ sp|Q6NLB0|GSTL1_ARATH    ┆ [-1.7563e-02 -1.1700e-01  1.25… │
│ sp|P82233|BR1A_RANTE     ┆ [-4.3030e-02 -1.8921e-02  2.26… │
│ sp|Q6P5F6|S39AA_MOUSE    ┆ [-9.6008e-02 -1.5625e-02  5.49… │
│ sp|A0A7E6FSU6|OXDD_OCTVU ┆ [-3.0200e-01 -1.5735e-01  1.91… │
└──────────────────────────┴─────────────────────────────────┘


In [24]:
embeddings_df = embeddings_df.with_columns(
    pl.col("Sequence_ID")
      .str.split("|") 
      .alias("Sequence_ID_parts")
)
print(embeddings_df.head())

shape: (5, 3)
┌──────────────────────────┬─────────────────────────────────┬─────────────────────────────────┐
│ Sequence_ID              ┆ Embedding                       ┆ Sequence_ID_parts               │
│ ---                      ┆ ---                             ┆ ---                             │
│ str                      ┆ object                          ┆ list[str]                       │
╞══════════════════════════╪═════════════════════════════════╪═════════════════════════════════╡
│ sp|D3ZHV2|MACF1_RAT      ┆ [-1.1731e-01 -1.2988e-01  3.58… ┆ ["sp", "D3ZHV2", "MACF1_RAT"]   │
│ sp|Q6NLB0|GSTL1_ARATH    ┆ [-1.7563e-02 -1.1700e-01  1.25… ┆ ["sp", "Q6NLB0", "GSTL1_ARATH"… │
│ sp|P82233|BR1A_RANTE     ┆ [-4.3030e-02 -1.8921e-02  2.26… ┆ ["sp", "P82233", "BR1A_RANTE"]  │
│ sp|Q6P5F6|S39AA_MOUSE    ┆ [-9.6008e-02 -1.5625e-02  5.49… ┆ ["sp", "Q6P5F6", "S39AA_MOUSE"… │
│ sp|A0A7E6FSU6|OXDD_OCTVU ┆ [-3.0200e-01 -1.5735e-01  1.91… ┆ ["sp", "A0A7E6FSU6", "OXDD_OCT… │
└───────────────

In [26]:
embeddings_df = embeddings_df.with_columns(
    pl.col("Sequence_ID_parts").list.get(1).alias("protein_accession_id")
)
print(embeddings_df.head())

shape: (5, 4)
┌──────────────────────────┬────────────────────┬──────────────────────┬──────────────────────┐
│ Sequence_ID              ┆ Embedding          ┆ Sequence_ID_parts    ┆ protein_accession_id │
│ ---                      ┆ ---                ┆ ---                  ┆ ---                  │
│ str                      ┆ object             ┆ list[str]            ┆ str                  │
╞══════════════════════════╪════════════════════╪══════════════════════╪══════════════════════╡
│ sp|D3ZHV2|MACF1_RAT      ┆ [-1.1731e-01       ┆ ["sp", "D3ZHV2",     ┆ D3ZHV2               │
│                          ┆ -1.2988e-01  3.58… ┆ "MACF1_RAT"]         ┆                      │
│ sp|Q6NLB0|GSTL1_ARATH    ┆ [-1.7563e-02       ┆ ["sp", "Q6NLB0",     ┆ Q6NLB0               │
│                          ┆ -1.1700e-01  1.25… ┆ "GSTL1_ARATH"…       ┆                      │
│ sp|P82233|BR1A_RANTE     ┆ [-4.3030e-02       ┆ ["sp", "P82233",     ┆ P82233               │
│                         

In [28]:
# You can then cast the embedding column as before
embeddings_df_final = embeddings_df.with_columns(
    pl.col("Embedding").list.to_array()
)

TypeError: ExprListNameSpace.to_array() missing 1 required positional argument: 'width'

In [14]:
embeddings_df.write_parquet('train_protein_features_esm2.parquet')

ComputeError: cannot write 'Object' datatype to parquet