## Install required libraries

In [1]:
!pip install obonet networkx

Collecting obonet
  Downloading obonet-1.1.1-py3-none-any.whl.metadata (6.7 kB)
Downloading obonet-1.1.1-py3-none-any.whl (9.2 kB)
Installing collected packages: obonet
Successfully installed obonet-1.1.1


In [2]:
!pip install duckdb --no-index --find-links=/kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/duck_pkg
!pip install polars --no-index --find-links=/kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/polars_pkg

Looking in links: /kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/duck_pkg
Looking in links: /kaggle/input/polars-and-duckdb/kaggle/working/mysitepackages/polars_pkg


In [3]:
!pip install biopython

Collecting biopython
  Downloading biopython-1.86-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)
Downloading biopython-1.86-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m45.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.86


## Import required modules

In [4]:
import numpy as np
import pandas as pd
import os
import duckdb as dd
import polars as pl
import h5py
import networkx as nx
import obonet as ob

import seaborn as sns
import matplotlib.pyplot as plt

from Bio import SeqIO

## Separate the GO terms by ontology (BP, CC, MF)

In [5]:
def separate_go_terms_by_namespace(graph):
    """
    Group the GO term IDs of an obonet graph by ontology namespace.

    Args:
        graph: A networkx MultiDiGraph loaded with obonet; each node's attribute
            dict may carry a 'namespace' entry.

    Returns:
        A dict with keys 'BP', 'CC' and 'MF' (in that order), each mapping to the
        list of term IDs in that namespace, in graph node order. Nodes whose
        namespace is missing or not one of the three GO namespaces are skipped.
    """
    # (namespace value, output key) pairs, compared in this order with ==.
    namespace_keys = (
        ('biological_process', 'BP'),
        ('cellular_component', 'CC'),
        ('molecular_function', 'MF'),
    )
    grouped = {key: [] for _, key in namespace_keys}

    for term_id, attrs in graph.nodes(data=True):
        term_namespace = attrs.get('namespace')
        for namespace_name, key in namespace_keys:
            if term_namespace == namespace_name:
                grouped[key].append(term_id)
                break

    return grouped

## Load the Gene Ontology (GO) basic OBO file (`go-basic.obo`) provided with the competition data

In [6]:
# Parse the GO OBO file into a networkx MultiDiGraph (one node per GO term).
go_file_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo'
print(f"Loading Gene Ontology from {go_file_path}...")
go_graph = ob.read_obo(go_file_path)
# len() of a graph is its node count, i.e. the number of GO terms loaded.
print(f"Total terms loaded: {len(go_graph)}")

Loading Gene Ontology from /kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo...
Total terms loaded: 40122


## Separate the terms

In [7]:
# Dict with 'BP', 'CC' and 'MF' lists of GO term IDs, used to split annotations below.
separated_terms = separate_go_terms_by_namespace(go_graph)

In [8]:
# Summarise how many GO terms fall into each namespace.
print("\nGO Term Counts by Namespace:")
print(f"Biological Process (BP): {len(separated_terms['BP'])} terms")
print(f"Cellular Component (CC): {len(separated_terms['CC'])} terms")
print(f"Molecular Function (MF): {len(separated_terms['MF'])} terms")

print("\nFirst 5 Cellular Component (CC) terms:")
for term_id in separated_terms['CC'][:5]:
    # Node attribute dicts hold the human-readable term name under 'name'.
    name = go_graph.nodes[term_id].get('name')
    print(f"- {term_id}: {name}")


GO Term Counts by Namespace:
Biological Process (BP): 25950 terms
Cellular Component (CC): 4041 terms
Molecular Function (MF): 10131 terms

First 5 Cellular Component (CC) terms:
- GO:0000015: phosphopyruvate hydratase complex
- GO:0000109: nucleotide-excision repair complex
- GO:0000110: nucleotide-excision repair factor 1 complex
- GO:0000111: nucleotide-excision repair factor 2 complex
- GO:0000112: nucleotide-excision repair factor 3 complex


## Roots of the three ontologies

In [10]:
# Materialise the graph's edges into a list for ad-hoc inspection.
go_graph_edges = list(go_graph.edges)

# Print the attribute dicts of the three ontology roots, separated by a rule:
# GO:0008150 (biological_process), GO:0005575 (cellular_component),
# GO:0003674 (molecular_function).
for root_index, root_id in enumerate(("GO:0008150", "GO:0005575", "GO:0003674")):
    if root_index > 0:
        print("*************")
    print(go_graph.nodes[root_id])

{'name': 'biological_process', 'namespace': 'biological_process', 'alt_id': ['GO:0000004', 'GO:0007582', 'GO:0044699'], 'def': '"A biological process is the execution of a genetically-encoded biological module or program. It consists of all the steps required to achieve the specific biological objective of the module. A biological process is accomplished by a particular set of molecular functions carried out by specific gene products (or macromolecular complexes), often in a highly regulated manner and in a particular temporal sequence." [GOC:pdt]', 'comment': "Note that, in addition to forming the root of the biological process ontology, this term is recommended for the annotation of gene products whose biological process is unknown. When this term is used for annotation, it indicates that no information was available about the biological process of the gene product annotated as of the date the annotation was made; the evidence code 'no data' (ND), is used to indicate this.", 'subset'

In [30]:
# Edges point child -> parent (see the parent/child edge cells below), so a term
# with no incoming edges has no children: it is a leaf of its ontology.
leaf_nodes = [
    term_id
    for term_id, n_children in go_graph.in_degree()
    if not n_children
]
len(leaf_nodes)

21936

In [37]:
# Pick one leaf term (index 100 looks like an arbitrary choice) and compute the
# shortest distance from it to every term reachable along child -> parent edges,
# i.e. to each of its ancestors (and to itself at distance 0).
leaf = leaf_nodes[100]
# Alternative: distance to a single target only, here the MF root.
#depths = nx.shortest_path_length(go_graph, source=leaf, target='GO:0003674')
depths = nx.shortest_path_length(go_graph, source=leaf)

In [53]:
# Largest shortest-path distance from the chosen leaf to any of its ancestors.
max(depths.values())

7

In [49]:
# With only `target` given, networkx returns {node: shortest distance from node to
# target}. Because edges point child -> parent, each value is that term's minimum
# number of hops up to the root; terms outside the ontology are not keys.
depths_from_bp_root = nx.shortest_path_length(go_graph, target="GO:0008150")
depths_from_cc_root = nx.shortest_path_length(go_graph, target="GO:0005575")
depths_from_mf_root = nx.shortest_path_length(go_graph, target="GO:0003674")

In [54]:
# NOTE: "depth" here is the maximum over terms of the *shortest* distance to the
# root; a term's longest path to the root can be longer.
for depth_label, depth_by_term in (
    ("Max depth of BP ontology -> ", depths_from_bp_root),
    ("Max depth of CC ontology -> ", depths_from_cc_root),
    ("Max depth of MF ontology -> ", depths_from_mf_root),
):
    print(depth_label, max(depth_by_term.values()))

Max depth of BP ontology ->  11
Max depth of CC ontology ->  9
Max depth of MF ontology ->  10


In [None]:
# Lookup tables between GO IDs and human-readable term names.
# id_to_name covers every node (None when a node has no 'name');
# name_to_id only covers named nodes, and a repeated name keeps the last ID seen.
id_to_name = dict(
    (term_id, attrs.get("name"))
    for term_id, attrs in go_graph.nodes(data=True)
)

name_to_id = dict(
    (attrs["name"], term_id)
    for term_id, attrs in go_graph.nodes(data=True)
    if "name" in attrs
)

print("id_to_name['GO:0008150'] -> ", id_to_name['GO:0008150'])
print("name_to_id['cellular_component'] -> ", name_to_id['cellular_component'])

In [None]:
# Find edges to parent terms
# Out-edges point child -> parent, so this lists the parents of the CC root
# (presumably none, since it is a root).
node = name_to_id["cellular_component"]
for src, dst, rel in go_graph.out_edges(node, keys=True):
    print(f"• {id_to_name[src]} ⟶ {rel} ⟶ {id_to_name[dst]}")

In [None]:
# Find edges to children terms
# Direct children of each ontology root: in-edges point child -> parent, so every
# in-edge of a root comes from one of its children. One loop replaces three
# copy-pasted blocks. `node` ends as the MF root, as before.
for root_position, root_name in enumerate(
    ("cellular_component", "biological_process", "molecular_function")
):
    if root_position:
        print("**********************")
    node = name_to_id[root_name]
    for child, parent, key in go_graph.in_edges(node, keys=True):
        print(f"• {id_to_name[parent]} ⟵ {key} ⟵ {id_to_name[child]}")

In [None]:
# Find edges to children terms
def print_descendant_edges(graph, root, max_depth=2):
    """Print `parent ⟵ key ⟵ child` edges below `root`, breadth-first.

    Edges in the obonet graph point child -> parent, so `in_edges(term)` yields
    the direct children of `term`. Each term is expanded at most once, so a term
    reachable through several parents does not have its subtree printed twice.
    Lines are indented two spaces per level below `root`.

    Args:
        graph: GO graph (networkx MultiDiGraph loaded with obonet).
        root: GO ID to start from.
        max_depth: Number of levels to descend; None walks the whole subtree
            (can be thousands of lines for a full ontology).
    """
    frontier = [root]
    expanded = {root}
    depth = 0
    while frontier and (max_depth is None or depth < max_depth):
        indent = "  " * depth
        next_frontier = []
        for term in frontier:
            for child, parent, key in graph.in_edges(term, keys=True):
                print(f"{indent}{parent} ⟵ {key} ⟵ {child}")
                if child not in expanded:
                    expanded.add(child)
                    next_frontier.append(child)
        frontier = next_frontier
        depth += 1


# Find edges to children terms, and their children. The previous version looped on
# `while child:` without ever changing `child`, so it never terminated.
node = name_to_id["cellular_component"]
print_descendant_edges(go_graph, node, max_depth=2)

In [None]:
# In-edges of `node` (the CC root) as (child, parent, key) triples: its direct children.
go_graph.in_edges(node, keys=True)

In [None]:
# Out-edges of `node` (its parents); expected to be empty for an ontology root.
go_graph.out_edges(node, keys=True)

## Load the 480-dimensional ESM-2 protein embeddings

In [None]:
# Precomputed 480-dimensional ESM-2 embeddings for the training proteins
# (columns used later: protein_accession_id, embedding_arrays).
emb_df = pl.read_parquet('/kaggle/input/cafa6-protein-go-terms-feat-labels/train_protein_features_esm2_480.parquet')
emb_df.shape

In [None]:
# Preview the first rows of the embedding frame.
emb_df.head(5)

In [None]:
# Distinct proteins with an embedding. DuckDB resolves `emb_df` to the Polars frame
# of the same name. Compare with emb_df.shape[0] to detect proteins with several rows.
dd.sql(" select count(distinct(protein_accession_id)) as proteins from emb_df ").pl()

## Load the training data: proteins and their annotated GO terms

In [None]:
# Long-format training annotations, one row per annotation (columns used below:
# EntryID, term, aspect).
train_terms_df = pl.read_csv('/kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv', separator='\t')
train_terms_df.shape

In [None]:
# Spot-check GO:0000322, the term tracked through the later filtering steps.
train_terms_df.filter(pl.col('term') == 'GO:0000322')

In [None]:
# Distinct annotated training proteins (DuckDB reads the Polars frame by name).
dd.sql(" select count(distinct(EntryID)) as proteins from train_terms_df ").pl()

## Join the protein embeddings onto the training annotations

In [None]:
# Attach each protein's embedding to every one of its GO annotations.
# DuckDB resolves `emb_df` and `train_terms_df` to the Polars frames of the same
# name; DISTINCT drops exact duplicate (protein, term, aspect, embedding) rows.
# The adjacent literals concatenate to the same single-line SQL as before.
train_terms_w_embeds = dd.sql(
    "select distinct t1.protein_accession_id, t2.term, t2.aspect "
    ", t1.embedding_arrays as protein_embedding "
    "from emb_df t1 "
    "join train_terms_df t2 "
    "on t1.protein_accession_id = t2.EntryID"
).pl()

print(train_terms_w_embeds.shape)
# Spot-check the term tracked through the rest of the notebook.
print(train_terms_w_embeds.filter(pl.col('term') == 'GO:0000322'))

In [None]:
# Preview the joined frame: protein, term, aspect and protein_embedding.
train_terms_w_embeds.head(3)

In [None]:
# Distinct proteins that have both an embedding and at least one annotation.
dd.sql("select count(distinct(protein_accession_id)) as uniq_protein_accession_ids from train_terms_w_embeds").pl()

## Create separate dataframes for each root ontology

In [None]:
# One frame per ontology: keep only the rows whose GO term belongs to that namespace.
train_terms_w_embeds_bp, train_terms_w_embeds_cc, train_terms_w_embeds_mf = (
    train_terms_w_embeds.filter(pl.col("term").is_in(separated_terms[namespace_key]))
    for namespace_key in ("BP", "CC", "MF")
)

for shape_label, ontology_frame in (
    ("train_terms_w_embeds_bp shape -> ", train_terms_w_embeds_bp),
    ("train_terms_w_embeds_cc shape -> ", train_terms_w_embeds_cc),
    ("train_terms_w_embeds_mf shape -> ", train_terms_w_embeds_mf),
):
    print(shape_label, ontology_frame.shape)

In [None]:
# Check whether GO:0000322 landed in the CC frame after the namespace split.
train_terms_w_embeds_cc.filter(pl.col('term') == 'GO:0000322')

## Get the 1500 most frequently occurring GO terms per ontology

In [None]:
def plot_top_term_frequencies(terms_df: pl.DataFrame, ontology: str, top_n: int = 100):
    """Bar-plot the `top_n` most frequent GO terms of one ontology.

    Replaces three copy-pasted cells that differed only in their input frame.
    Those cells all carried the same title, so the BP/CC/MF figures could not be
    told apart; the ontology is now part of the title.

    Args:
        terms_df: Annotation rows with a 'term' column.
        ontology: Label shown in the title, e.g. "BP".
        top_n: Number of most frequent terms to show.

    Returns:
        The matplotlib Axes holding the plot.
    """
    freq_pd = (
        terms_df.group_by("term")
        .agg(pl.len().alias("freq"))
        # Break count ties by term ID so the selected top-N is deterministic.
        .sort(["freq", "term"], descending=[True, False])
        .head(top_n)
        .to_pandas()
    )

    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    sns.barplot(data=freq_pd, x="term", y="freq", ax=ax)
    ax.tick_params(axis="x", labelrotation=90, labelsize=6)
    ax.set_title(f"Top {top_n} frequent {ontology} GO term IDs")
    ax.set_xlabel("GO term IDs", fontsize=12)
    ax.set_ylabel("Count", fontsize=12)
    plt.show()
    return ax


for ontology_label, ontology_df in (
    ("BP", train_terms_w_embeds_bp),
    ("CC", train_terms_w_embeds_cc),
    ("MF", train_terms_w_embeds_mf),
):
    plot_top_term_frequencies(ontology_df, ontology_label)

In [None]:
from typing import List

num_of_labels = 1500

def train_labels_per_ontology_pl(input_df: pl.DataFrame, input_num_of_labels: int) -> pl.DataFrame:
    """
    Filters a Polars DataFrame to the rows whose term is among the top N most
    frequent terms.

    Args:
        input_df: The input Polars DataFrame with a 'term' column.
        input_num_of_labels: The number of top terms to keep.

    Returns:
        A new Polars DataFrame with all original columns, containing only the
        rows whose term is one of the top N frequent labels.
    """

    # 1. Count rows per term and rank them. `pl.len()` replaces the deprecated
    #    `GroupBy.count()`. Ties on the count are broken by term ID: group_by
    #    output order is arbitrary and the sort is not stable by default, so
    #    without a tie-breaker the terms kept at the cut-off could change
    #    between runs.
    top_labels_df = (
        input_df.group_by("term")
        .agg(pl.len().alias("count"))
        .sort(["count", "term"], descending=[True, False])
        .limit(input_num_of_labels)
    )

    # 2. Extract these top terms into a Python list
    labels_list: List[str] = top_labels_df["term"].to_list()

    # 3. Filter the original DataFrame to keep only those terms
    train_labels = input_df.filter(pl.col("term").is_in(labels_list))

    print(f"Shape of filtered training data: {train_labels.shape}")
    print("Head of filtered training data:")
    print(train_labels.head())

    return train_labels


In [None]:
# Keep only each ontology's `num_of_labels` most frequent terms (each call prints a summary).
train_terms_bp, train_terms_cc, train_terms_mf = (
    train_labels_per_ontology_pl(ontology_df, num_of_labels)
    for ontology_df in (train_terms_w_embeds_bp, train_terms_w_embeds_cc, train_terms_w_embeds_mf)
)

In [None]:
# Check whether GO:0000322 survived the top-N label cut for CC.
train_terms_cc.filter(pl.col('term') == 'GO:0000322')

In [None]:
# Distinct proteins left per ontology after the top-N cut. DuckDB resolves each
# table name to the module-level Polars frame of the same name.
for frame_name in ("train_terms_bp", "train_terms_cc", "train_terms_mf"):
    print(dd.sql(f"select count(distinct(protein_accession_id)) as uniq_protein_accession_ids from {frame_name}").pl())

In [None]:
# Long-format (protein, term) label pairs per ontology, without embeddings.
train_protein_labels_bp, train_protein_labels_cc, train_protein_labels_mf = (
    ontology_terms.select(['protein_accession_id', 'term'])
    for ontology_terms in (train_terms_bp, train_terms_cc, train_terms_mf)
)

for label_pairs in (train_protein_labels_bp, train_protein_labels_cc, train_protein_labels_mf):
    print(label_pairs.head(5))

In [None]:
# Same GO:0000322 spot-check after dropping the embedding and aspect columns.
train_protein_labels_cc.filter(pl.col('term') == 'GO:0000322')

In [None]:
def get_pivoted_df(input_df: pl.DataFrame) -> pl.DataFrame:
    """Turn long (protein, term) rows into a wide 0/1 multi-label matrix.

    Args:
        input_df: Frame with 'protein_accession_id' and 'term' columns; any other
            columns are ignored.

    Returns:
        One row per protein, sorted by protein_accession_id. The frame has a
        'protein_accession_id' column followed by one column per GO term, sorted
        by term ID. Each term column holds 1.0 if the protein carries that term
        and 0.0 otherwise.
    """
    label_pairs = (
        input_df.select(["protein_accession_id", "term"])
        # A (protein, term) pair can occur more than once, e.g. when the upstream
        # join produced several rows for the same protein. Summing duplicates
        # would yield labels > 1, so keep each pair once and aggregate with max.
        .unique()
        .with_columns(pl.lit(1.0).alias("presence"))
    )

    pivoted_df = (
        label_pairs.pivot(
            index="protein_accession_id",
            on="term",
            values="presence",
            aggregate_function="max",
        )
        .fill_null(0.0)
        .sort("protein_accession_id")
    )

    # Pivot columns follow each term's first appearance, which depends on
    # upstream row order; sort them so the saved layout is reproducible.
    term_columns = sorted(c for c in pivoted_df.columns if c != "protein_accession_id")
    pivoted_df = pivoted_df.select(["protein_accession_id", *term_columns])

    print(pivoted_df.head(5))
    return pivoted_df

In [None]:
# Wide protein x term label matrices, one per ontology (each call prints a preview).
pivoted_df_bp, pivoted_df_cc, pivoted_df_mf = (
    get_pivoted_df(label_pairs)
    for label_pairs in (train_protein_labels_bp, train_protein_labels_cc, train_protein_labels_mf)
)

In [None]:
# Number of CC proteins whose GO:0000322 label is 1.0; should match the number
# of distinct proteins carrying that term in the long-format frames above.
pivoted_df_cc.group_by(['GO:0000322']).agg(pl.col("protein_accession_id").count())\
.filter(pl.col('GO:0000322')==1.0)

In [None]:
# (proteins, 1 + number of kept terms) for each label matrix.
for label_matrix in (pivoted_df_bp, pivoted_df_cc, pivoted_df_mf):
    print(label_matrix.shape)

In [None]:
# Feature rows per ontology: each distinct (protein, embedding) pair among the
# proteins that kept at least one top-N label in that ontology.
# NOTE(review): `select distinct` gives no row-order guarantee, so these rows are
# not positionally aligned with pivoted_df_*; join on protein_accession_id downstream.
train_protein_features_bp = dd.sql("select distinct protein_accession_id, protein_embedding from train_terms_bp").pl()
print(train_protein_features_bp.head(5))

train_protein_features_cc = dd.sql("select distinct protein_accession_id, protein_embedding from train_terms_cc").pl()
print(train_protein_features_cc.head(5))

train_protein_features_mf = dd.sql("select distinct protein_accession_id, protein_embedding from train_terms_mf").pl()
print(train_protein_features_mf.head(5))

In [None]:
# Persist labels and features per ontology to the working directory
# (MF, then CC, then BP; labels before features within each).
for output_frame, output_path in (
    (pivoted_df_mf, 'train_protein_labels_mf.parquet'),
    (train_protein_features_mf, 'train_protein_features_mf.parquet'),
    (pivoted_df_cc, 'train_protein_labels_cc.parquet'),
    (train_protein_features_cc, 'train_protein_features_cc.parquet'),
    (pivoted_df_bp, 'train_protein_labels_bp.parquet'),
    (train_protein_features_bp, 'train_protein_features_bp.parquet'),
):
    output_frame.write_parquet(output_path)

In [None]:
# Record ID (Biopython's `.id`, the first word of each FASTA header) for every
# protein in the test superset, in file order.
test_prot_ids = list(
    record.id
    for record in SeqIO.parse("/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta", "fasta")
)

In [None]:
# Precomputed 480-dimensional ESM-2 embeddings for the test proteins.
test_emb_df = pl.read_parquet('/kaggle/input/cafa6-protein-go-terms-feat-labels/test_protein_features_esm2_480.parquet')
test_emb_df.shape

In [None]:
# Keep only embeddings of proteins listed in the test superset FASTA.
# NOTE(review): this relies on FASTA record IDs matching protein_accession_id
# exactly; compare the two shapes to confirm the formats agree.
test_embeds = test_emb_df.filter(pl.col("protein_accession_id").is_in(test_prot_ids))
test_embeds.shape

In [None]:
# NOTE(review): previews the unfiltered `test_emb_df`, not `test_embeds` — confirm which was intended.
test_emb_df.head(4)

In [None]:
# Distinct test proteins that have an embedding (DuckDB reads `test_embeds` by name).
dd.sql("select count(distinct(protein_accession_id)) as uniq_protein_accession_ids from test_embeds").pl()