## Install required libraries

In [1]:
!pip install duckdb --no-index --find-links=/kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/duck_pkg
!pip install polars --no-index --find-links=/kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/polars_pkg
!pip install biopython

Looking in links: /kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/duck_pkg
Looking in links: /kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/polars_pkg
Collecting biopython
  Downloading biopython-1.86-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)
Downloading biopython-1.86-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m83.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.86


## Import required modules

In [2]:
import torch

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import random
import pprint
import os
import duckdb as dd
import polars as pl
from Bio import SeqIO

import transformers

In [4]:
## Run this twice, first time, it will show error
from transformers import AutoTokenizer, EsmModel

## Declare necessary global variables

In [5]:
model_to_use = 'facebook/esm2_t12_35M_UR50D'

train_fasta_file_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta'
test_fasta_file_path = '/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta'

output_train_embed_dir = "/kaggle/working/esm_embeddings/train"
output_test_embed_dir = "/kaggle/working/esm_embeddings/test"

## Create the embeddings generation function

In [6]:
from tqdm.notebook import tqdm # Use tqdm for a progress bar in Kaggle/Jupyter

def generate_embeddings_optimized(fasta_file, output_dir, model_name='facebook/esm2_t12_35M_UR50D', batch_size=4):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Load model in float16 for memory efficiency (Requires modern GPU/hardware)
    model = EsmModel.from_pretrained(model_name, torch_dtype=torch.float16) 
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    os.makedirs(output_dir, exist_ok=True)

    # Read all sequences into a list first to prepare for batching
    sequences = list(SeqIO.parse(fasta_file, "fasta"))
    print(f"Total sequences to process: {len(sequences)}")

    for i in tqdm(range(0, len(sequences), batch_size)):
        batch = sequences[i:i+batch_size]
        seq_ids = [record.id for record in batch]
        seqs = [str(record.seq) for record in batch]

        # Tokenize the batch
        inputs = tokenizer(seqs, return_tensors='pt', padding=True, truncation=True, max_length=1022)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Extract and save embeddings individually to manage memory
        for j, seq_id in enumerate(seq_ids):
            # Mean pooling for sequence-level embedding
            embedding = outputs.last_hidden_state[j].mean(dim=0).cpu().numpy()
            output_path = os.path.join(output_dir, f"{seq_id}.pt")
            torch.save({'embeddings': embedding}, output_path)

        # Explicitly clear input and output tensors to free VRAM immediately
        del inputs, outputs, batch
        torch.cuda.empty_cache()

    print("Embedding generation complete.")

## Run the function on train and test FASTA files

In [7]:
generate_embeddings_optimized(train_fasta_file_path, output_train_embed_dir, model_to_use, 8)

tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/778 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/136M [00:00<?, ?B/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total sequences to process: 82404


  0%|          | 0/10301 [00:00<?, ?it/s]

Embedding generation complete.


In [18]:
generate_embeddings_optimized(test_fasta_file_path, output_test_embed_dir, model_to_use, 8)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total sequences to process: 224309


  0%|          | 0/28039 [00:00<?, ?it/s]

Embedding generation complete.


## Convert the embeddings into a dataframe

In [19]:
def embeddings_to_dataframe(directory_path):
    """Loads embeddings and organizes them into a Pandas DataFrame."""
    data_list = []
    for filename in tqdm(os.listdir(directory_path)):
        if filename.endswith(".pt"):
            file_path = os.path.join(directory_path, filename)
            data = torch.load(file_path, weights_only=False)
            seq_id = os.path.splitext(filename)[0]
            
            # The embedding data itself is in 'embeddings'
            embedding = data['embeddings'] 
            
            # Create a row for the DataFrame
            row = {'Sequence_ID': seq_id, 'Embedding': embedding}
            data_list.append(row)
            
    #df = pd.DataFrame(data_list)
    df = pl.DataFrame(data_list)
    print(f"DataFrame created with {len(df)} entries.")
    return df

In [20]:
# Example usage:
embeddings_df = embeddings_to_dataframe(output_test_embed_dir)
print(embeddings_df.head())

  0%|          | 0/224309 [00:00<?, ?it/s]

DataFrame created with 224309 entries.
shape: (5, 2)
┌─────────────┬─────────────────────────────────┐
│ Sequence_ID ┆ Embedding                       │
│ ---         ┆ ---                             │
│ str         ┆ object                          │
╞═════════════╪═════════════════════════════════╡
│ Q61066      ┆ [-1.3489e-01  4.8401e-02  4.23… │
│ P50390      ┆ [-4.9591e-02 -1.0577e-01  4.65… │
│ P38882      ┆ [-1.4490e-01  1.0150e-01  1.71… │
│ A5AB61      ┆ [-7.3425e-02 -1.2781e-01  2.15… │
│ Q8WYP5      ┆ [-2.0996e-02  9.3307e-03 -1.12… │
└─────────────┴─────────────────────────────────┘


In [21]:
embeddings_df = embeddings_df.with_columns(
    pl.col("Sequence_ID")
      .str.split("|") 
      .alias("Sequence_ID_parts")
)
print(embeddings_df.head())

shape: (5, 3)
┌─────────────┬─────────────────────────────────┬───────────────────┐
│ Sequence_ID ┆ Embedding                       ┆ Sequence_ID_parts │
│ ---         ┆ ---                             ┆ ---               │
│ str         ┆ object                          ┆ list[str]         │
╞═════════════╪═════════════════════════════════╪═══════════════════╡
│ Q61066      ┆ [-1.3489e-01  4.8401e-02  4.23… ┆ ["Q61066"]        │
│ P50390      ┆ [-4.9591e-02 -1.0577e-01  4.65… ┆ ["P50390"]        │
│ P38882      ┆ [-1.4490e-01  1.0150e-01  1.71… ┆ ["P38882"]        │
│ A5AB61      ┆ [-7.3425e-02 -1.2781e-01  2.15… ┆ ["A5AB61"]        │
│ Q8WYP5      ┆ [-2.0996e-02  9.3307e-03 -1.12… ┆ ["Q8WYP5"]        │
└─────────────┴─────────────────────────────────┴───────────────────┘


In [22]:
embeddings_df = embeddings_df.with_columns(
    pl.col("Sequence_ID_parts").list.get(0).alias("protein_accession_id") ## 0 for test proteins, 1 for train proteins
)
print(embeddings_df.head())

shape: (5, 4)
┌─────────────┬─────────────────────────────────┬───────────────────┬──────────────────────┐
│ Sequence_ID ┆ Embedding                       ┆ Sequence_ID_parts ┆ protein_accession_id │
│ ---         ┆ ---                             ┆ ---               ┆ ---                  │
│ str         ┆ object                          ┆ list[str]         ┆ str                  │
╞═════════════╪═════════════════════════════════╪═══════════════════╪══════════════════════╡
│ Q61066      ┆ [-1.3489e-01  4.8401e-02  4.23… ┆ ["Q61066"]        ┆ Q61066               │
│ P50390      ┆ [-4.9591e-02 -1.0577e-01  4.65… ┆ ["P50390"]        ┆ P50390               │
│ P38882      ┆ [-1.4490e-01  1.0150e-01  1.71… ┆ ["P38882"]        ┆ P38882               │
│ A5AB61      ┆ [-7.3425e-02 -1.2781e-01  2.15… ┆ ["A5AB61"]        ┆ A5AB61               │
│ Q8WYP5      ┆ [-2.0996e-02  9.3307e-03 -1.12… ┆ ["Q8WYP5"]        ┆ Q8WYP5               │
└─────────────┴─────────────────────────────────┴───────

In [23]:
embeddings_df['Embedding'][800].shape

(480,)

In [24]:
list_of_lists = embeddings_df['Embedding'].to_list()

In [25]:
numpy_array_fixed = np.array(list_of_lists)

# Create a new Polars Series from the NumPy array
# Polars automatically detects the fixed shape and uses the Array dtype
array_series = pl.Series("embedding_arrays", numpy_array_fixed)

print("\nNew Series Dtype:")
print(array_series.dtype)


New Series Dtype:
Array(Float32, shape=(480,))


In [26]:
embeddings_df_semifinal = embeddings_df.with_columns(
    array_series
).drop("Embedding")

print("\nFinal DataFrame:")
print(embeddings_df_semifinal.head())
print(f"Final dtypes: {embeddings_df_semifinal.dtypes}")


Final DataFrame:
shape: (5, 4)
┌─────────────┬───────────────────┬──────────────────────┬─────────────────────────────────┐
│ Sequence_ID ┆ Sequence_ID_parts ┆ protein_accession_id ┆ embedding_arrays                │
│ ---         ┆ ---               ┆ ---                  ┆ ---                             │
│ str         ┆ list[str]         ┆ str                  ┆ array[f32, 480]                 │
╞═════════════╪═══════════════════╪══════════════════════╪═════════════════════════════════╡
│ Q61066      ┆ ["Q61066"]        ┆ Q61066               ┆ [-0.134888, 0.048401, … -0.113… │
│ P50390      ┆ ["P50390"]        ┆ P50390               ┆ [-0.049591, -0.105774, … 0.181… │
│ P38882      ┆ ["P38882"]        ┆ P38882               ┆ [-0.144897, 0.101501, … 0.0782… │
│ A5AB61      ┆ ["A5AB61"]        ┆ A5AB61               ┆ [-0.073425, -0.127808, … 0.099… │
│ Q8WYP5      ┆ ["Q8WYP5"]        ┆ Q8WYP5               ┆ [-0.020996, 0.009331, … 0.0836… │
└─────────────┴───────────────────┴───

In [27]:
embeddings_df_final = embeddings_df_semifinal.select(['protein_accession_id','embedding_arrays'])
print(embeddings_df_final.head())

shape: (5, 2)
┌──────────────────────┬─────────────────────────────────┐
│ protein_accession_id ┆ embedding_arrays                │
│ ---                  ┆ ---                             │
│ str                  ┆ array[f32, 480]                 │
╞══════════════════════╪═════════════════════════════════╡
│ Q61066               ┆ [-0.134888, 0.048401, … -0.113… │
│ P50390               ┆ [-0.049591, -0.105774, … 0.181… │
│ P38882               ┆ [-0.144897, 0.101501, … 0.0782… │
│ A5AB61               ┆ [-0.073425, -0.127808, … 0.099… │
│ Q8WYP5               ┆ [-0.020996, 0.009331, … 0.0836… │
└──────────────────────┴─────────────────────────────────┘


In [28]:
embeddings_df_final.write_parquet('test_protein_features_esm2_480.parquet')