## Install required libraries

In [1]:
!pip install duckdb --no-index --find-links=/kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/duck_pkg
!pip install polars --no-index --find-links=/kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/polars_pkg
!pip install biopython

Looking in links: /kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/duck_pkg
Looking in links: /kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/polars_pkg
Collecting biopython
  Downloading biopython-1.86-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)
Downloading biopython-1.86-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m46.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.86


## Import required modules

In [2]:
import torch

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import random
import pprint
import os
import duckdb as dd
import polars as pl
from Bio import SeqIO

import transformers

In [4]:
## Run this twice, first time, it will show error
from transformers import AutoTokenizer, EsmModel

## Declare necessary global variables

In [5]:
model_to_use = 'facebook/esm2_t6_8M_UR50D'

train_fasta_file_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta'
test_fasta_file_path = '/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta'

output_train_embed_dir = "/kaggle/working/esm_embeddings/train"
output_test_embed_dir = "/kaggle/working/esm_embeddings/test"

## Create the embeddings generation function

In [6]:
from tqdm.notebook import tqdm # Use tqdm for a progress bar in Kaggle/Jupyter

def generate_embeddings_optimized(fasta_file, output_dir, model_name='facebook/esm2_t12_35M_UR50D', batch_size=4):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Load model in float16 for memory efficiency (Requires modern GPU/hardware)
    model = EsmModel.from_pretrained(model_name, torch_dtype=torch.float16) 
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    os.makedirs(output_dir, exist_ok=True)

    # Read all sequences into a list first to prepare for batching
    sequences = list(SeqIO.parse(fasta_file, "fasta"))
    print(f"Total sequences to process: {len(sequences)}")

    for i in tqdm(range(0, len(sequences), batch_size)):
        batch = sequences[i:i+batch_size]
        seq_ids = [record.id for record in batch]
        seqs = [str(record.seq) for record in batch]

        # Tokenize the batch
        inputs = tokenizer(seqs, return_tensors='pt', padding=True, truncation=True, max_length=1022)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Extract and save embeddings individually to manage memory
        for j, seq_id in enumerate(seq_ids):
            # Mean pooling for sequence-level embedding
            embedding = outputs.last_hidden_state[j].mean(dim=0).cpu().numpy()
            output_path = os.path.join(output_dir, f"{seq_id}.pt")
            torch.save({'embeddings': embedding}, output_path)

        # Explicitly clear input and output tensors to free VRAM immediately
        del inputs, outputs, batch
        torch.cuda.empty_cache()

    print("Embedding generation complete.")

## Run the function on train and test FASTA files

In [7]:
generate_embeddings_optimized(train_fasta_file_path, output_train_embed_dir, model_to_use, 8)

tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/775 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/31.4M [00:00<?, ?B/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total sequences to process: 82404


  0%|          | 0/10301 [00:00<?, ?it/s]

Embedding generation complete.


In [None]:
generate_embeddings_optimized(test_fasta_file_path, output_test_embed_dir, model_to_use, 8)

## Convert the embeddings into a dataframe

In [8]:
def embeddings_to_dataframe(directory_path):
    """Loads embeddings and organizes them into a Pandas DataFrame."""
    data_list = []
    for filename in tqdm(os.listdir(directory_path)):
        if filename.endswith(".pt"):
            file_path = os.path.join(directory_path, filename)
            data = torch.load(file_path, weights_only=False)
            seq_id = os.path.splitext(filename)[0]
            
            # The embedding data itself is in 'embeddings'
            embedding = data['embeddings'] 
            
            # Create a row for the DataFrame
            row = {'Sequence_ID': seq_id, 'Embedding': embedding}
            data_list.append(row)
            
    #df = pd.DataFrame(data_list)
    df = pl.DataFrame(data_list)
    print(f"DataFrame created with {len(df)} entries.")
    return df

In [10]:
# Example usage:
embeddings_df = embeddings_to_dataframe(output_train_embed_dir)
print(embeddings_df.head())

  0%|          | 0/82404 [00:00<?, ?it/s]

DataFrame created with 82404 entries.
shape: (5, 2)
┌───────────────────────┬─────────────────────────────────┐
│ Sequence_ID           ┆ Embedding                       │
│ ---                   ┆ ---                             │
│ str                   ┆ object                          │
╞═══════════════════════╪═════════════════════════════════╡
│ sp|P07288|KLK3_HUMAN  ┆ [-1.1871e-01 -2.0483e-01  2.88… │
│ sp|P26200|PROF2_DICDI ┆ [-1.5637e-01  5.8075e-02  1.31… │
│ sp|P53617|NRD1_YEAST  ┆ [-1.2909e-02 -8.6853e-02  5.00… │
│ sp|P50150|GBG4_HUMAN  ┆ [-1.1218e-01 -1.4636e-01  1.39… │
│ sp|Q91XE4|ACY3_MOUSE  ┆ [-3.9764e-02 -1.0461e-01  1.65… │
└───────────────────────┴─────────────────────────────────┘


In [11]:
embeddings_df = embeddings_df.with_columns(
    pl.col("Sequence_ID")
      .str.split("|") 
      .alias("Sequence_ID_parts")
)
print(embeddings_df.head())

shape: (5, 3)
┌───────────────────────┬─────────────────────────────────┬─────────────────────────────────┐
│ Sequence_ID           ┆ Embedding                       ┆ Sequence_ID_parts               │
│ ---                   ┆ ---                             ┆ ---                             │
│ str                   ┆ object                          ┆ list[str]                       │
╞═══════════════════════╪═════════════════════════════════╪═════════════════════════════════╡
│ sp|P07288|KLK3_HUMAN  ┆ [-1.1871e-01 -2.0483e-01  2.88… ┆ ["sp", "P07288", "KLK3_HUMAN"]  │
│ sp|P26200|PROF2_DICDI ┆ [-1.5637e-01  5.8075e-02  1.31… ┆ ["sp", "P26200", "PROF2_DICDI"… │
│ sp|P53617|NRD1_YEAST  ┆ [-1.2909e-02 -8.6853e-02  5.00… ┆ ["sp", "P53617", "NRD1_YEAST"]  │
│ sp|P50150|GBG4_HUMAN  ┆ [-1.1218e-01 -1.4636e-01  1.39… ┆ ["sp", "P50150", "GBG4_HUMAN"]  │
│ sp|Q91XE4|ACY3_MOUSE  ┆ [-3.9764e-02 -1.0461e-01  1.65… ┆ ["sp", "Q91XE4", "ACY3_MOUSE"]  │
└───────────────────────┴─────────────────────

In [12]:
embeddings_df = embeddings_df.with_columns(
    pl.col("Sequence_ID_parts").list.get(1).alias("protein_accession_id")
)
print(embeddings_df.head())

shape: (5, 4)
┌───────────────────────┬────────────────────┬───────────────────┬──────────────────────┐
│ Sequence_ID           ┆ Embedding          ┆ Sequence_ID_parts ┆ protein_accession_id │
│ ---                   ┆ ---                ┆ ---               ┆ ---                  │
│ str                   ┆ object             ┆ list[str]         ┆ str                  │
╞═══════════════════════╪════════════════════╪═══════════════════╪══════════════════════╡
│ sp|P07288|KLK3_HUMAN  ┆ [-1.1871e-01       ┆ ["sp", "P07288",  ┆ P07288               │
│                       ┆ -2.0483e-01  2.88… ┆ "KLK3_HUMAN"]     ┆                      │
│ sp|P26200|PROF2_DICDI ┆ [-1.5637e-01       ┆ ["sp", "P26200",  ┆ P26200               │
│                       ┆ 5.8075e-02  1.31…  ┆ "PROF2_DICDI"…    ┆                      │
│ sp|P53617|NRD1_YEAST  ┆ [-1.2909e-02       ┆ ["sp", "P53617",  ┆ P53617               │
│                       ┆ -8.6853e-02  5.00… ┆ "NRD1_YEAST"]     ┆                    

In [14]:
embeddings_df['Embedding'][800].shape

(320,)

In [15]:
list_of_lists = embeddings_df['Embedding'].to_list()

In [16]:
numpy_array_fixed = np.array(list_of_lists)

# Create a new Polars Series from the NumPy array
# Polars automatically detects the fixed shape and uses the Array dtype
array_series = pl.Series("embedding_arrays", numpy_array_fixed)

print("\nNew Series Dtype:")
print(array_series.dtype)


New Series Dtype:
Array(Float32, shape=(320,))


In [17]:
embeddings_df_semifinal = embeddings_df.with_columns(
    array_series
).drop("Embedding")

print("\nFinal DataFrame:")
print(embeddings_df_semifinal.head())
print(f"Final dtypes: {embeddings_df_semifinal.dtypes}")


Final DataFrame:
shape: (5, 4)
┌───────────────────────┬───────────────────┬──────────────────────┬──────────────────────────┐
│ Sequence_ID           ┆ Sequence_ID_parts ┆ protein_accession_id ┆ embedding_arrays         │
│ ---                   ┆ ---               ┆ ---                  ┆ ---                      │
│ str                   ┆ list[str]         ┆ str                  ┆ array[f32, 320]          │
╞═══════════════════════╪═══════════════════╪══════════════════════╪══════════════════════════╡
│ sp|P07288|KLK3_HUMAN  ┆ ["sp", "P07288",  ┆ P07288               ┆ [-0.118713, -0.204834, … │
│                       ┆ "KLK3_HUMAN"]     ┆                      ┆ -0.21…                   │
│ sp|P26200|PROF2_DICDI ┆ ["sp", "P26200",  ┆ P26200               ┆ [-0.156372, 0.058075, …  │
│                       ┆ "PROF2_DICDI"…    ┆                      ┆ -0.080…                  │
│ sp|P53617|NRD1_YEAST  ┆ ["sp", "P53617",  ┆ P53617               ┆ [-0.012909, -0.086853, … │
│       

In [18]:
embeddings_df_final = embeddings_df_semifinal.select(['protein_accession_id','embedding_arrays'])
print(embeddings_df_final.head())

shape: (5, 2)
┌──────────────────────┬─────────────────────────────────┐
│ protein_accession_id ┆ embedding_arrays                │
│ ---                  ┆ ---                             │
│ str                  ┆ array[f32, 320]                 │
╞══════════════════════╪═════════════════════════════════╡
│ P07288               ┆ [-0.118713, -0.204834, … -0.21… │
│ P26200               ┆ [-0.156372, 0.058075, … -0.080… │
│ P53617               ┆ [-0.012909, -0.086853, … 0.040… │
│ P50150               ┆ [-0.112183, -0.146362, … -0.08… │
│ Q91XE4               ┆ [-0.039764, -0.104614, … -0.13… │
└──────────────────────┴─────────────────────────────────┘


In [19]:
embeddings_df_final.write_parquet('train_protein_features_esm2_320.parquet')