<a href="https://colab.research.google.com/github/oshbocker/CAFA/blob/main/CAFA_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!mkdir ~/.kaggle
!touch ~/.kaggle/kaggle.json

api_token = {"username":"oshbocker","key":"bb9c6a60ba5c39e689d8cf8d15cb8bca"}

import json

with open('/root/.kaggle/kaggle.json', 'w') as file:
    json.dump(api_token, file)

!chmod 600 ~/.kaggle/kaggle.json

In [2]:
!pip install obonet -q
!pip install biopython -q
!pip install torchmetrics -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m63.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
import os
import torch

from pathlib import Path


iskaggle = os.environ.get('KAGGLE_KERNEL_RUN_TYPE', '')

cafa_main_path = Path("cafa-5-protein-function-prediction")
cafa_clean_fasta_path = Path("viktorfairuschin/cafa-5-fasta-files")
esm2_path = Path("viktorfairuschin/cafa-5-ems-2-embeddings-numpy")
protbert_path = Path("henriupton/protbert-embeddings-for-cafa5")
t5_path = Path("sergeifironov/t5embeds")

data_paths = [cafa_clean_fasta_path, esm2_path, protbert_path, t5_path]

if iskaggle:
  content_dir = "/kaggle/input"
else:
  content_dir = "/content"

class CFG:
    train_go_obo_path: str = f"{content_dir}/cafa-5-protein-function-prediction/Train/go-basic.obo"
    train_seq_fasta_path: str = f"{content_dir}/cafa-5-protein-function-prediction/Train/train_sequences.fasta"
    train_terms_path: str = f"{content_dir}/cafa-5-protein-function-prediction/Train/train_terms.tsv"
    train_taxonomy_path: str = f"{content_dir}/cafa-5-protein-function-prediction/Train/train_taxonomy.tsv"
    train_ia_path: str = f"{content_dir}/cafa-5-protein-function-prediction/IA.txt"
    test_sequences_path: str = f"{content_dir}/Test (Targets)/testsuperset.fasta"

    num_labels = 500
    n_epochs = 15
    batch_size = 128
    lr = 0.001

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [4]:
if not iskaggle:
  import zipfile,kaggle
  if not cafa_main_path.exists():
    print(cafa_main_path)
    kaggle.api.competition_download_cli(str(cafa_main_path))
    zipfile.ZipFile(f'{cafa_main_path}.zip').extractall(cafa_main_path)
  for data_path in data_paths:
    if not data_path.exists():
      print(data_path)
      kaggle.api.dataset_download_files(str(data_path))
      zipfile.ZipFile(f'{data_path.name}.zip').extractall(data_path.name)

cafa-5-protein-function-prediction
Downloading cafa-5-protein-function-prediction.zip to /content


100%|██████████| 115M/115M [00:04<00:00, 26.0MB/s]



viktorfairuschin/cafa-5-fasta-files
viktorfairuschin/cafa-5-ems-2-embeddings-numpy
henriupton/protbert-embeddings-for-cafa5
sergeifironov/t5embeds


In [5]:
import obonet
import time

import networkx as nx
import pandas as pd
import numpy as np
from Bio import SeqIO
from tqdm import tqdm
import matplotlib.pyplot as plt
plt.style.use('ggplot')

# TORCH MODULES FOR METRICS COMPUTATION :
import torch
from torch.utils.data import Dataset
from torch import nn
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MultilabelF1Score
from torchmetrics.classification import MultilabelAccuracy

In [6]:
graph = obonet.read_obo(CFG.train_go_obo_path)

In [7]:
# Information Accretion values for terms
ia_dict = {}
with open(CFG.train_ia_path) as f:
    for line in f:
        (key, value) = line.split("\t")
        ia_dict[key] = float(value.replace("\n",""))

In [8]:
subontology_roots = {'BPO':'GO:0008150',
                     'CCO':'GO:0005575',
                     'MFO':'GO:0003674'}

BPO_len = len(nx.ancestors(graph, subontology_roots['BPO']))
CCO_len = len(nx.ancestors(graph, subontology_roots['CCO']))
MFO_len = len(nx.ancestors(graph, subontology_roots['MFO']))
print(BPO_len, CCO_len, MFO_len, BPO_len + CCO_len + MFO_len)

27941 4042 11262 43245


In [9]:
train_terms = pd.read_csv(CFG.train_terms_path, sep="\t")
train_terms['ia'] = train_terms['term'].apply(lambda x: ia_dict[x])
print(train_terms.shape)
train_terms.head(10)

(5363863, 4)


Unnamed: 0,EntryID,term,aspect,ia
0,A0A009IHW8,GO:0008152,BPO,1.598544
1,A0A009IHW8,GO:0034655,BPO,0.042274
2,A0A009IHW8,GO:0072523,BPO,0.083901
3,A0A009IHW8,GO:0044270,BPO,0.281155
4,A0A009IHW8,GO:0006753,BPO,0.013844
5,A0A009IHW8,GO:1901292,BPO,0.0
6,A0A009IHW8,GO:0044237,BPO,0.10367
7,A0A009IHW8,GO:1901360,BPO,1.237575
8,A0A009IHW8,GO:0008150,BPO,0.0
9,A0A009IHW8,GO:1901564,BPO,0.557385


## Get the leaf predictions for each protein
A leaf prediction is a GO term that doesn't have any children that are also predictions for the protein.

In [10]:
def get_leaf_labels(train_terms, aspect):
    # Group by protein and get a list of all go term predictions
    protein_list_go_term = train_terms[train_terms['aspect'] == aspect].groupby('EntryID')['term'].apply(list)

    leaf_labels = {}
    for i, protein in enumerate(protein_list_go_term.index):
        test_terms = set(protein_list_go_term[protein])
        descendants = set()
        for term in test_terms:
            descendants.update(nx.descendants(graph, term))

        specific_labels = test_terms.difference(descendants)
        leaf_labels[protein] = list(specific_labels)

    return leaf_labels

from collections import deque

# Get ordered list of all ancestors
def leaf_label_ancestors(graph, leaf_labels):
    ordered_edges_dict = {}
    for label in leaf_labels:
        # Get ancestors
        ordered_edges = deque()
        node_is_root = False
        prev_node_id = label
        while node_is_root == False:
            node = graph.nodes[prev_node_id]
            node_is_root = 'is_a' not in node
            if not node_is_root:
                prev_node_id = node['is_a'][0]
                ordered_edges.appendleft(prev_node_id)
        ordered_edges_dict[label] = ordered_edges

    return ordered_edges_dict

In [62]:
CCO_leaf_labels = get_leaf_labels(train_terms, 'CCO')
print(len(CCO_leaf_labels.values()))
CCO_all_leaf_labels = np.unique(np.concatenate(list(CCO_leaf_labels.values())))
print(len(CCO_all_leaf_labels))
CCO_nodes = [n for n,v in graph.nodes(data=True) if v['namespace'] == 'cellular_component']
CCO_ordered_edges = leaf_label_ancestors(graph, CCO_nodes)

92912
2763


In [54]:
graph.nodes['GO:0000001']

{'name': 'mitochondrion inheritance',
 'namespace': 'biological_process',
 'def': '"The distribution of mitochondria, including the mitochondrial genome, into daughter cells after mitosis or meiosis, mediated by interactions between mitochondria and the cytoskeleton." [GOC:mcc, PMID:10873824, PMID:11389764]',
 'synonym': ['"mitochondrial inheritance" EXACT []'],
 'is_a': ['GO:0048308', 'GO:0048311']}

In [12]:
# Get the first node label
data_with_labels = []
# Get all CCO labeled proteins with leaf labels
for protein, label_list in CCO_leaf_labels.items():
    for label in label_list:
        data_with_labels.append({'protein': protein,
                                 'label': label,
                                 'ordered_edges': CCO_ordered_edges[label]})

In [13]:
def get_next_label(rw, depth=1):
    if not rw['labeled'] and len(rw['ordered_edges']) > depth:
        return rw['ordered_edges'][-depth]
    else:
        return rw['balanced_label']

def is_final_balanced_label(rw, balanced_labels, depth=1):
    if not rw['labeled']:
        return (rw['balanced_label'] in balanced_labels) or (len(rw['ordered_edges']) == depth)
    else:
        return True

def get_balanced_labels(label_df, label_count_threshold):
    df = label_df.copy()
    balanced_labels = set([label for label, count in df['label'].value_counts().items()
                           if count >= label_count_threshold])
    df['labeled'] = df['label'].apply(lambda x: x in balanced_labels)
    df['balanced_label'] = df['label'] # df.apply(lambda x: x['label'] if x['labeled'] == True else None, axis=1)
    print(sum(df['labeled'])/df.shape[0])
    print(len(balanced_labels))
    print(df['balanced_label'].value_counts())
    depth = 1
    while sum(df['labeled'])/df.shape[0] < 1:
        df['balanced_label'] = df.apply(lambda x: get_next_label(x, depth=depth), axis=1)
        balanced_labels = set([label for label, count in df['balanced_label'].value_counts().items()
                                   if count >= label_count_threshold])
        df['labeled'] = df.apply(lambda x: is_final_balanced_label(x, balanced_labels, depth=depth), axis=1)
        depth += 1
        print(sum(df['labeled'])/df.shape[0])
        print(len(balanced_labels))
        print(df['balanced_label'].value_counts())

    return df

In [14]:
CCO_df = pd.DataFrame(data_with_labels)

new_CCO_df = get_balanced_labels(CCO_df, 1358)

0.46988094968653105
14
GO:0005829    16981
GO:0005634    12661
GO:0005886    11509
GO:0005737     8335
GO:0005654     8160
              ...  
GO:0036025        1
GO:0097182        1
GO:0061835        1
GO:0061834        1
GO:0071202        1
Name: balanced_label, Length: 2763, dtype: int64
0.6748296856256931
26
GO:0005829    16981
GO:0110165    12989
GO:0005634    12725
GO:0005886    11823
GO:0005737     8932
              ...  
GO:0016471        1
GO:0044174        1
GO:0031002        1
GO:0030929        1
GO:0098539        1
Name: balanced_label, Length: 678, dtype: int64
0.8399723875698798
29
GO:0110165    21336
GO:0005829    17050
GO:0005634    12910
GO:0005886    11823
GO:0005737     8948
              ...  
GO:0099572        1
GO:0033176        1
GO:0097518        1
GO:0005788        1
GO:0042170        1
Name: balanced_label, Length: 272, dtype: int64
0.9202521331733925
29
GO:0110165    25516
GO:0005829    17050
GO:0005634    12938
GO:0005886    11823
GO:0032991    10919
      

## Classify selected nodes of CCO

In [15]:
from sklearn.model_selection import train_test_split

In [16]:
ohe_CCO_df = pd.get_dummies(new_CCO_df, prefix=['balanced_label'], columns=['balanced_label'], drop_first=True)
label_cols = [c for c in ohe_CCO_df.columns if 'balanced_label_' in c]
label_df = ohe_CCO_df.groupby('protein').sum(numeric_only=True)[label_cols].clip(0,1)
label_values = list(label_df.values)
label_ids = list(label_df.index)
labels_df = pd.DataFrame(data={"EntryID": label_ids, "labels_vect": label_values})

# Split into train, val, test
X_train, X_test, y_train, y_test = train_test_split(label_ids,
                                                    label_values,
                                                    test_size = 0.2)
print("Train label", len(y_train))
print("Test label", len(y_test))

Train label 74329
Test label 18583


In [17]:
# Directories for the different embedding vectors :
embeds_map = {
    "T5" : "t5embeds",
    "ProtBERT" : "protbert-embeddings-for-cafa5",
    "ESM2" : "cafa-5-ems-2-embeddings-numpy"
}

# Length of the different embedding vectors :
embeds_dim = {
    "T5" : 1024,
    "ProtBERT" : 1024,
    "ESM2" : 1280,
    "Concat": 1024+1024,
}

In [18]:
class ProteinSequenceDataset(Dataset):

    def __init__(self, datatype, embeddings_source):
        super(ProteinSequenceDataset).__init__()
        self.datatype = datatype

        if embeddings_source in ["ProtBERT", "ESM2"]:
            embeds = np.load(f"{content_dir}/"+embeds_map[embeddings_source]+"/"+datatype+"_embeddings.npy")
            ids = np.load(f"{content_dir}/"+embeds_map[embeddings_source]+"/"+datatype+"_ids.npy")

        if embeddings_source == "T5":
            embeds = np.load(f"{content_dir}/"+embeds_map[embeddings_source]+"/"+datatype+"_embeds.npy")
            ids = np.load(f"{content_dir}/"+embeds_map[embeddings_source]+"/"+datatype+"_ids.npy")

        if embeddings_source == "Concat":
            bert_embeds = np.load(f"{content_dir}/"+embeds_map["ProtBERT"]+"/"+datatype+"_embeddings.npy")
            bert_ids = np.load(f"{content_dir}/"+embeds_map["ProtBERT"]+"/"+datatype+"_ids.npy")
            esm2_embeds = np.load(f"{content_dir}/"+embeds_map["ESM2"]+"/"+datatype+"_embeddings.npy")
            esm2_ids = np.load(f"{content_dir}/"+embeds_map["ESM2"]+"/"+datatype+"_ids.npy")
            t5_embeds = np.load(f"{content_dir}/"+embeds_map["T5"]+"/"+datatype+"_embeds.npy")
            t5_ids = np.load(f"{content_dir}/"+embeds_map["T5"]+"/"+datatype+"_ids.npy")
            embeds = np.concatenate([bert_embeds, t5_embeds], axis=1)
            embeds_list = []
            for l in range(embeds.shape[0]):
                embeds_list.append(embeds[l,:])
            self.df = pd.DataFrame(data={"EntryID": t5_ids, "embed" : embeds_list})
        else:
            embeds_list = []
            for l in range(embeds.shape[0]):
                embeds_list.append(embeds[l,:])
            self.df = pd.DataFrame(data={"EntryID": ids, "embed" : embeds_list})


        if datatype=="train":
            self.df = self.df.merge(labels_df, on="EntryID")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        embed = torch.tensor(self.df.iloc[index]["embed"] , dtype = torch.float32)
        if self.datatype=="train":
            targets = torch.tensor(self.df.iloc[index]["labels_vect"], dtype = torch.float32)
            return embed, targets
        if self.datatype=="test":
            id = self.df.iloc[index]["EntryID"]
            return embed, id

In [19]:
datatype = "train"
bert_embeds = np.load(f"{content_dir}/"+embeds_map["ProtBERT"]+"/"+datatype+"_embeddings.npy")
bert_ids = np.load(f"{content_dir}/"+embeds_map["ProtBERT"]+"/"+datatype+"_ids.npy")
esm2_embeds = np.load(f"{content_dir}/"+embeds_map["ESM2"]+"/"+datatype+"_embeddings.npy")
esm2_ids = np.load(f"{content_dir}/"+embeds_map["ESM2"]+"/"+datatype+"_ids.npy")
t5_embeds = np.load(f"{content_dir}/"+embeds_map["T5"]+"/"+datatype+"_embeds.npy")
t5_ids = np.load(f"{content_dir}/"+embeds_map["T5"]+"/"+datatype+"_ids.npy")

In [20]:
bert_embeds.shape

(142246, 1024)

In [21]:
np.concatenate([bert_embeds, t5_embeds], axis=1).shape

(142246, 2048)

In [22]:
print(bert_ids[:5])
print(esm2_ids[:5])
print(t5_ids[:5])

['P20536' 'O73864' 'O95231' 'A0A0B4J1F4' 'P54366']
['Q9ZSA8' 'P25353' 'A0A2R8YCW8' 'G3V5N8' 'A0A140LFN4']
['P20536' 'O73864' 'O95231' 'A0A0B4J1F4' 'P54366']


In [23]:
class MultiLayerPerceptron(torch.nn.Module):

    def __init__(self, input_dim, num_classes):
        super(MultiLayerPerceptron, self).__init__()

        self.linear1 = torch.nn.Linear(input_dim, 4*1012)
        self.activation1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(4*1012, 4*712)
        self.activation2 = torch.nn.ReLU()
        self.linear3 = torch.nn.Linear(4*712, num_classes)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation1(x)
        x = self.linear2(x)
        x = self.activation2(x)
        x = self.linear3(x)
        return x


class CNN1D(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(CNN1D, self).__init__()
        # (batch_size, channels, embed_size)
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=3, kernel_size=3, dilation=1, padding=1, stride=1)
        # (batch_size, 3, embed_size)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        # (batch_size, 3, embed_size/2 = 512)
        self.conv2 = nn.Conv1d(in_channels=3, out_channels=8, kernel_size=3, dilation=1, padding=1, stride=1)
        # (batch_size, 8, embed_size/2 = 512)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        # (batch_size, 8, embed_size/4 = 256)
        self.fc1 = nn.Linear(in_features=int(8 * input_dim/4), out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        x = x.reshape(x.shape[0], 1, x.shape[1])
        x = self.pool1(nn.functional.relu(self.conv1(x)))
        x = self.pool2(nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [24]:
def train_model(embeddings_source, model_type="linear", train_size=0.9):

    train_dataset = ProteinSequenceDataset(datatype="train", embeddings_source = embeddings_source)

    train_set, val_set = random_split(train_dataset, lengths = [int(len(train_dataset)*train_size), len(train_dataset)-int(len(train_dataset)*train_size)])
    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=CFG.batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=CFG.batch_size, shuffle=True)

    if model_type == "linear":
        model = MultiLayerPerceptron(input_dim=embeds_dim[embeddings_source], num_classes=CFG.num_labels).to(CFG.device)
    if model_type == "convolutional":
        model = CNN1D(input_dim=embeds_dim[embeddings_source], num_classes=CFG.num_labels).to(CFG.device)

    optimizer = torch.optim.Adam(model.parameters(), lr = CFG.lr)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=1)
    CrossEntropy = torch.nn.CrossEntropyLoss()
    f1_score = MultilabelF1Score(num_labels=CFG.num_labels).to(CFG.device)
    n_epochs = CFG.n_epochs

    print("BEGIN TRAINING...")
    train_loss_history=[]
    val_loss_history=[]

    train_f1score_history=[]
    val_f1score_history=[]
    for epoch in range(n_epochs):
        print("EPOCH ", epoch+1)
        ## TRAIN PHASE :
        losses = []
        scores = []
        for embed, targets in tqdm(train_dataloader):
            embed, targets = embed.to(CFG.device), targets.to(CFG.device)
            optimizer.zero_grad()
            preds = model(embed)
            loss= CrossEntropy(preds, targets)
            score=f1_score(preds, targets)
            losses.append(loss.item())
            scores.append(score.item())
            loss.backward()
            optimizer.step()
        avg_loss = np.mean(losses)
        avg_score = np.mean(scores)
        print("Running Average TRAIN Loss : ", avg_loss)
        print("Running Average TRAIN F1-Score : ", avg_score)
        train_loss_history.append(avg_loss)
        train_f1score_history.append(avg_score)

        ## VALIDATION PHASE :
        losses = []
        scores = []
        for embed, targets in val_dataloader:
            embed, targets = embed.to(CFG.device), targets.to(CFG.device)
            preds = model(embed)
            loss= CrossEntropy(preds, targets)
            score=f1_score(preds, targets)
            losses.append(loss.item())
            scores.append(score.item())
        avg_loss = np.mean(losses)
        avg_score = np.mean(scores)
        print("Running Average VAL Loss : ", avg_loss)
        print("Running Average VAL F1-Score : ", avg_score)
        val_loss_history.append(avg_loss)
        val_f1score_history.append(avg_score)

        scheduler.step(avg_loss)
        print("\n")

    print("TRAINING FINISHED")
    print("FINAL TRAINING SCORE : ", train_f1score_history[-1])
    print("FINAL VALIDATION SCORE : ", val_f1score_history[-1])

    losses_history = {"train" : train_loss_history, "val" : val_loss_history}
    scores_history = {"train" : train_f1score_history, "val" : val_f1score_history}

    return model, losses_history, scores_history

In [25]:
train_dataset = ProteinSequenceDataset(datatype="train", embeddings_source = "ESM2")
CFG.num_labels = train_dataset[0][1].shape[0]

In [26]:
cat_model, cat_losses,cat_scores = train_model(embeddings_source="Concat",model_type="linear")

BEGIN TRAINING...
EPOCH  1


100%|██████████| 654/654 [00:17<00:00, 38.44it/s]


Running Average TRAIN Loss :  4.309215068817139
Running Average TRAIN F1-Score :  0.18460498731412472
Running Average VAL Loss :  4.132621435269917
Running Average VAL F1-Score :  0.21345987324028798


EPOCH  2


100%|██████████| 654/654 [00:14<00:00, 46.06it/s]


Running Average TRAIN Loss :  4.028870807875187
Running Average TRAIN F1-Score :  0.22148050155934937
Running Average VAL Loss :  4.012003624275939
Running Average VAL F1-Score :  0.22788507379081152


EPOCH  3


100%|██████████| 654/654 [00:14<00:00, 45.79it/s]


Running Average TRAIN Loss :  3.9144588617738965
Running Average TRAIN F1-Score :  0.23836573473083864
Running Average VAL Loss :  3.972759922889814
Running Average VAL F1-Score :  0.23690510592232011


EPOCH  4


100%|██████████| 654/654 [00:14<00:00, 46.06it/s]


Running Average TRAIN Loss :  3.822534039479877
Running Average TRAIN F1-Score :  0.2512424760227539
Running Average VAL Loss :  3.9246545621793563
Running Average VAL F1-Score :  0.24412091076374054


EPOCH  5


100%|██████████| 654/654 [00:14<00:00, 46.22it/s]


Running Average TRAIN Loss :  3.729467245052349
Running Average TRAIN F1-Score :  0.2637020042593326
Running Average VAL Loss :  3.88666471389875
Running Average VAL F1-Score :  0.25541863457797326


EPOCH  6


100%|██████████| 654/654 [00:14<00:00, 43.92it/s]


Running Average TRAIN Loss :  3.637652509803072
Running Average TRAIN F1-Score :  0.27580714745259066
Running Average VAL Loss :  3.8792266845703125
Running Average VAL F1-Score :  0.26133485587492383


EPOCH  7


100%|██████████| 654/654 [00:15<00:00, 43.58it/s]


Running Average TRAIN Loss :  3.544577402441509
Running Average TRAIN F1-Score :  0.2888175667198062
Running Average VAL Loss :  3.8764705200717873
Running Average VAL F1-Score :  0.27248394305575385


EPOCH  8


100%|██████████| 654/654 [00:14<00:00, 45.96it/s]


Running Average TRAIN Loss :  3.4511416195364903
Running Average TRAIN F1-Score :  0.30117124895528186
Running Average VAL Loss :  3.8537190581021243
Running Average VAL F1-Score :  0.27725819590157025


EPOCH  9


100%|██████████| 654/654 [00:14<00:00, 46.10it/s]


Running Average TRAIN Loss :  3.3512623149685172
Running Average TRAIN F1-Score :  0.315311939321709
Running Average VAL Loss :  3.8992081537638623
Running Average VAL F1-Score :  0.2826612150832398


EPOCH  10


100%|██████████| 654/654 [00:14<00:00, 46.17it/s]


Running Average TRAIN Loss :  3.258155630998291
Running Average TRAIN F1-Score :  0.3290095023514663
Running Average VAL Loss :  3.917451215116945
Running Average VAL F1-Score :  0.2949233442953188


EPOCH  11


100%|██████████| 654/654 [00:14<00:00, 45.88it/s]


Running Average TRAIN Loss :  2.9749038292362786
Running Average TRAIN F1-Score :  0.35359536216164217
Running Average VAL Loss :  3.865806279117114
Running Average VAL F1-Score :  0.302341223783689


EPOCH  12


100%|██████████| 654/654 [00:14<00:00, 46.33it/s]


Running Average TRAIN Loss :  2.9029624356406907
Running Average TRAIN F1-Score :  0.3588415162825803
Running Average VAL Loss :  3.877368659189303
Running Average VAL F1-Score :  0.3030507893186726


EPOCH  13


100%|██████████| 654/654 [00:14<00:00, 46.04it/s]


Running Average TRAIN Loss :  2.8483984371937745
Running Average TRAIN F1-Score :  0.3637708869366835
Running Average VAL Loss :  3.8955236852985538
Running Average VAL F1-Score :  0.3048317550796352


EPOCH  14


100%|██████████| 654/654 [00:14<00:00, 46.29it/s]


Running Average TRAIN Loss :  2.8385145879302187
Running Average TRAIN F1-Score :  0.36457143933583475
Running Average VAL Loss :  3.8948411027046097
Running Average VAL F1-Score :  0.3054901267567726


EPOCH  15


100%|██████████| 654/654 [00:14<00:00, 46.45it/s]


Running Average TRAIN Loss :  2.833039065748909
Running Average TRAIN F1-Score :  0.36570326628488137
Running Average VAL Loss :  3.8953075017014593
Running Average VAL F1-Score :  0.3057885316953267


TRAINING FINISHED
FINAL TRAINING SCORE :  0.36570326628488137
FINAL VALIDATION SCORE :  0.3057885316953267


TRAINING FINISHED
FINAL TRAINING SCORE :  0.31309230539047755
FINAL VALIDATION SCORE :  0.2773439412655896

In [27]:
esm2_model, esm2_losses, esm2_scores = train_model(embeddings_source="ESM2",model_type="linear")

BEGIN TRAINING...
EPOCH  1


100%|██████████| 654/654 [00:12<00:00, 51.27it/s]


Running Average TRAIN Loss :  4.303298401176383
Running Average TRAIN F1-Score :  0.19190664638392058
Running Average VAL Loss :  4.102618054167865
Running Average VAL F1-Score :  0.21781557430959728


EPOCH  2


100%|██████████| 654/654 [00:12<00:00, 51.09it/s]


Running Average TRAIN Loss :  4.038748616836851
Running Average TRAIN F1-Score :  0.22685360060919316
Running Average VAL Loss :  4.032249904658697
Running Average VAL F1-Score :  0.23316022120926477


EPOCH  3


100%|██████████| 654/654 [00:12<00:00, 51.67it/s]


Running Average TRAIN Loss :  3.911429770860468
Running Average TRAIN F1-Score :  0.24438878415374582
Running Average VAL Loss :  3.944485445545144
Running Average VAL F1-Score :  0.24569698596653872


EPOCH  4


100%|██████████| 654/654 [00:12<00:00, 51.31it/s]


Running Average TRAIN Loss :  3.792313178015776
Running Average TRAIN F1-Score :  0.260396705490369
Running Average VAL Loss :  3.894122982678348
Running Average VAL F1-Score :  0.2506396447959012


EPOCH  5


100%|██████████| 654/654 [00:12<00:00, 51.40it/s]


Running Average TRAIN Loss :  3.666811010524038
Running Average TRAIN F1-Score :  0.27573687065996527
Running Average VAL Loss :  3.873795182737586
Running Average VAL F1-Score :  0.2619835164857237


EPOCH  6


100%|██████████| 654/654 [00:12<00:00, 51.32it/s]


Running Average TRAIN Loss :  3.5379262309555615
Running Average TRAIN F1-Score :  0.29163513373617733
Running Average VAL Loss :  3.8717269472879905
Running Average VAL F1-Score :  0.2733974381260676


EPOCH  7


100%|██████████| 654/654 [00:12<00:00, 51.34it/s]


Running Average TRAIN Loss :  3.406507944477443
Running Average TRAIN F1-Score :  0.31077173258824453
Running Average VAL Loss :  3.8624430323300296
Running Average VAL F1-Score :  0.2782411959073315


EPOCH  8


100%|██████████| 654/654 [00:12<00:00, 51.13it/s]


Running Average TRAIN Loss :  3.266247654909023
Running Average TRAIN F1-Score :  0.3297991315159229
Running Average VAL Loss :  3.8736822082571787
Running Average VAL F1-Score :  0.28542896002939305


EPOCH  9


100%|██████████| 654/654 [00:12<00:00, 51.30it/s]


Running Average TRAIN Loss :  3.1294717821506186
Running Average TRAIN F1-Score :  0.34931441699510685
Running Average VAL Loss :  3.9017997277926093
Running Average VAL F1-Score :  0.29173390812253297


EPOCH  10


 36%|███▋      | 238/654 [00:04<00:08, 49.33it/s]


KeyboardInterrupt: ignored

In [None]:
# Weighted F1 Score?

In [None]:
t5_model, t5_losses, t5_scores = train_model(embeddings_source="T5",model_type="linear")

In [None]:
protbert_model, protbert_losses, protbert_scores = train_model(embeddings_source="T5",model_type="linear")

In [None]:
plt.figure(figsize = (10, 4))
plt.plot(esm2_losses["val"], label = "ESM2")
plt.plot(t5_losses["val"], label = "T5")
plt.plot(protbert_losses["val"], label = "ProtBERT")
plt.title("Validation Losses for # Vector Embeddings")
plt.xlabel("Epochs")
plt.ylabel("Average Loss")
plt.legend()
plt.show()

plt.figure(figsize = (10, 4))
plt.plot(esm2_scores["val"], label = "ESM2")
plt.plot(t5_scores["val"], label = "T5")
plt.plot(protbert_scores["val"], label = "ProtBERT")
plt.title("Validation F1-Scores for # Vector Embeddings")
plt.xlabel("Epochs")
plt.ylabel("Average F1-Score")
plt.legend()
plt.show()

## Make Predictions

In [35]:
def predict(embeddings_source):

    test_dataset = ProteinSequenceDataset(datatype="test", embeddings_source = embeddings_source)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

    if embeddings_source == "T5":
        model = t5_model
    if embeddings_source == "ProtBERT":
        model = protbert_model
    if embeddings_source == "EMS2":
        model = esm2_model
    if embeddings_source == "Concat":
        model = cat_model

    model.eval()

    labels = pd.read_csv(CFG.train_terms_path, sep = "\t")
    top_terms = labels.groupby("term")["EntryID"].count().sort_values(ascending=False)
    labels_names = [label.split('_')[-1] for label in label_cols]
    print("GENERATE PREDICTION FOR TEST SET...")

    ids_ = np.empty(shape=(len(test_dataloader)*CFG.num_labels,), dtype=object)
    go_terms_ = np.empty(shape=(len(test_dataloader)*CFG.num_labels,), dtype=object)
    confs_ = np.empty(shape=(len(test_dataloader)*CFG.num_labels,), dtype=np.float32)

    for i, (embed, id) in tqdm(enumerate(test_dataloader)):
        embed = embed.to(CFG.device)
        confs_[i*CFG.num_labels:(i+1)*CFG.num_labels] = torch.nn.functional.sigmoid(model(embed)).squeeze().detach().cpu().numpy()
        ids_[i*CFG.num_labels:(i+1)*CFG.num_labels] = id[0]
        go_terms_[i*CFG.num_labels:(i+1)*CFG.num_labels] = labels_names

    submission_df = pd.DataFrame(data={"Id" : ids_, "GO term" : go_terms_, "Confidence" : confs_})
    print("PREDICTIONS DONE")
    return submission_df

In [37]:
submission_df = predict("Concat")

GENERATE PREDICTION FOR TEST SET...


141865it [01:49, 1290.70it/s]


PREDICTIONS DONE


In [41]:
submission_df["GO term"].unique()

29

In [45]:
submission_df.head()

Unnamed: 0,Id,GO term,Confidence
0,Q9CQV8,GO:0005615,0.417892
1,Q9CQV8,GO:0005634,0.795579
2,Q9CQV8,GO:0005654,0.27319
3,Q9CQV8,GO:0005730,0.010239
4,Q9CQV8,GO:0005737,0.197824


In [48]:
submission_df['GO term'].unique()

array(['GO:0005615', 'GO:0005634', 'GO:0005654', 'GO:0005730',
       'GO:0005737', 'GO:0005739', 'GO:0005783', 'GO:0005794',
       'GO:0005815', 'GO:0005829', 'GO:0005886', 'GO:0005911',
       'GO:0009986', 'GO:0016020', 'GO:0030312', 'GO:0031090',
       'GO:0031410', 'GO:0032991', 'GO:0043005', 'GO:0043226',
       'GO:0043231', 'GO:0043232', 'GO:0044423', 'GO:0045202',
       'GO:0070013', 'GO:0070062', 'GO:0098590', 'GO:0110165',
       'GO:0140513'], dtype=object)

In [66]:
id_list = []
go_term_list = []
confidence_list = []
for rw in submission_df.values:
  edges = CCO_ordered_edges[rw[1]]
  for edge in edges:
    id_list.append(rw[0])
    go_term_list.append(rw[1])
    confidence_list.append(rw[2])

In [71]:
additional_labels = pd.DataFrame({'Id': id_list, 'GO term': go_term_list, 'Confidence': confidence_list})

In [73]:
submission_df = pd.concat([submission_df, additional_labels]).drop_duplicates()

In [74]:
submission_df

Unnamed: 0,Id,GO term,Confidence
0,Q9CQV8,GO:0005615,4.178925e-01
1,Q9CQV8,GO:0005634,7.955793e-01
2,Q9CQV8,GO:0005654,2.731899e-01
3,Q9CQV8,GO:0005730,1.023925e-02
4,Q9CQV8,GO:0005737,1.978240e-01
...,...,...,...
4114080,A0A3G2FQK2,GO:0070013,3.619519e-05
4114081,A0A3G2FQK2,GO:0070062,1.645986e-03
4114082,A0A3G2FQK2,GO:0098590,1.885436e-04
4114083,A0A3G2FQK2,GO:0110165,9.602937e-01


In [None]:
submission_df.to_csv('submission.tsv', sep='\t', index=False)