In [32]:
import numpy as np
import pandas as pd
import matplotlib as plt
import seaborn as sns
from time import time
from pathlib import Path
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MultilabelF1Score
from torchmetrics.classification import MultilabelAccuracy
from transformers import BertModel, BertTokenizer
from Bio import SeqIO

In [33]:
torch.manual_seed(42)
np.random.seed(42)
torch.cuda.manual_seed_all(42)

In [34]:
class config:
    num_labels = 500
    n_epochs = 25
    batch_size = 128
    lr = 0.001
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    embeds_dim = 1024

In [35]:
class paths:
    train_ids = "/kaggle/input/protbert-embeddings-for-cafa5/train_ids.npy"    
    train_embeddings = "/kaggle/input/protbert-embeddings-for-cafa5/train_embeddings.npy"
    test_ids = "/kaggle/input/protbert-embeddings-for-cafa5/test_ids.npy"
    test_embeddings = "/kaggle/input/protbert-embeddings-for-cafa5/test_embeddings.npy"
    train_labels_path = "/kaggle/input/cafa-5-protein-function-prediction/Train/train_terms.tsv"
    train_targets_path = "/kaggle/input/cafa5-label-vectors-numpy/train_targets_top500.npy"

In [36]:
class CustomProteinDataset(Dataset):
    def __init__(self, train=True):
        super(CustomProteinDataset).__init__()
        self.train=train
        if train:
            embeds = np.load(paths.train_embeddings)
            ids = np.load(paths.train_ids)
        else:
            embeds = np.load(paths.test_embeddings)
            ids = np.load(paths.test_ids)
        
        embeds_list = []
        for l in range(embeds.shape[0]):
            embeds_list.append(embeds[l,:])
        self.df = pd.DataFrame(data={"EntryID": ids, "embed" : embeds_list})
        
        if train:
            self.labels = np.load(paths.train_targets_path)
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        embed = torch.tensor(self.df.iloc[index]["embed"], dtype = torch.float32)
        if self.train:
            targets = torch.tensor(self.labels[index, :], dtype = torch.float32)
            return embed, targets
        else:
            id = self.df.iloc[index]["EntryID"]
            return embed, id

In [37]:
train_dataset = CustomProteinDataset()
test_dataset = CustomProteinDataset(train=False)

In [38]:
len(train_dataset)

142246

In [39]:
class MultiLayerPerceptron(torch.nn.Module):
    def __init__(self, input_dim, num_classes):
        super(MultiLayerPerceptron, self).__init__()
        self.linear1 = torch.nn.Linear(input_dim, 1012)
        self.activation1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(1012, 712)
        self.activation2 = torch.nn.ReLU()
        self.linear3 = torch.nn.Linear(712, num_classes)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation1(x)
        x = self.linear2(x)
        x = self.activation2(x)
        x = self.linear3(x)
        return x

In [40]:
class CNN1D(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(CNN1D, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=3, kernel_size=3, dilation=1, padding=1, stride=1)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(in_channels=3, out_channels=8, kernel_size=3, dilation=1, padding=1, stride=1)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=int(8 * input_dim/4), out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        x = x.reshape(x.shape[0], 1, x.shape[1])
        x = self.pool1(nn.functional.relu(self.conv1(x)))
        x = self.pool2(nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [41]:
def train(model_type="linear", train_size=0.9):
    train_set, val_set = random_split(train_dataset, lengths = [int(len(train_dataset)*train_size), len(train_dataset)-int(len(train_dataset)*train_size)])
    train_dataloader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
    val_dataloader = DataLoader(val_set, batch_size=config.batch_size, shuffle=True)
    
    if model_type == "linear":
        model = MultiLayerPerceptron(input_dim=config.embeds_dim, num_classes=config.num_labels).to(config.device)
    if model_type == "convolutional":
        model = CNN1D(input_dim=config.embeds_dim, num_classes=config.num_labels).to(config.device)
        
    optimizer = torch.optim.Adam(model.parameters(), lr = config.lr)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=1)
    CrossEntropy = torch.nn.CrossEntropyLoss()
    f1_score = MultilabelF1Score(num_labels=config.num_labels).to(config.device)
    n_epochs = config.n_epochs
    
    train_loss_history=[]
    val_loss_history=[]
    
    train_f1score_history=[]
    val_f1score_history=[]
    
    for epoch in range(n_epochs):
        print("EPOCH ", epoch+1)
        ## TRAIN PHASE :
        losses = []
        scores = []
        for embed, targets in tqdm(train_dataloader):
            embed, targets = embed.to(config.device), targets.to(config.device)
            optimizer.zero_grad()
            preds = model(embed)
            loss= CrossEntropy(preds, targets)
            score=f1_score(preds, targets)
            losses.append(loss.item()) 
            scores.append(score.item())
            loss.backward()
            optimizer.step()
        avg_loss = np.mean(losses)
        avg_score = np.mean(scores)
        print("Running Average TRAIN Loss : ", avg_loss)
        print("Running Average TRAIN F1-Score : ", avg_score)
        train_loss_history.append(avg_loss)
        train_f1score_history.append(avg_score)\
        
        losses = []
        scores = []
        for embed, targets in val_dataloader:
            embed, targets = embed.to(config.device), targets.to(config.device)
            preds = model(embed)
            loss= CrossEntropy(preds, targets)
            score=f1_score(preds, targets)
            losses.append(loss.item())
            scores.append(score.item())
        avg_loss = np.mean(losses)
        avg_score = np.mean(scores)
        print("Running Average VAL Loss : ", avg_loss)
        print("Running Average VAL F1-Score : ", avg_score)
        val_loss_history.append(avg_loss)
        val_f1score_history.append(avg_score)
        
        scheduler.step(avg_loss)
        print("\n")
        
    print("TRAINING FINISHED")
    print("FINAL TRAINING SCORE : ", train_f1score_history[-1])
    print("FINAL VALIDATION SCORE : ", val_f1score_history[-1])
    
    losses_history = {"train" : train_loss_history, "val" : val_loss_history}
    scores_history = {"train" : train_f1score_history, "val" : val_f1score_history}
    
    return model, losses_history, scores_history

In [42]:
model, losses_history, scores_history = train(model_type="linear", train_size=0.95)

EPOCH  1


100%|██████████| 1056/1056 [00:45<00:00, 23.38it/s]


Running Average TRAIN Loss :  138.00875971534035
Running Average TRAIN F1-Score :  0.10431985564487563
Running Average VAL Loss :  134.5035569327218
Running Average VAL F1-Score :  0.13277065607586078


EPOCH  2


100%|██████████| 1056/1056 [00:46<00:00, 22.73it/s]


Running Average TRAIN Loss :  135.07629364187068
Running Average TRAIN F1-Score :  0.14090815848304014
Running Average VAL Loss :  132.99446432931083
Running Average VAL F1-Score :  0.14797663063343083


EPOCH  3


100%|██████████| 1056/1056 [00:46<00:00, 22.90it/s]


Running Average TRAIN Loss :  133.96682144656324
Running Average TRAIN F1-Score :  0.1535308060560827
Running Average VAL Loss :  132.1038383756365
Running Average VAL F1-Score :  0.156761730621968


EPOCH  4


100%|██████████| 1056/1056 [00:45<00:00, 23.02it/s]


Running Average TRAIN Loss :  133.20018368056327
Running Average TRAIN F1-Score :  0.16114157686630884
Running Average VAL Loss :  131.64523492540633
Running Average VAL F1-Score :  0.16388671632323945


EPOCH  5


100%|██████████| 1056/1056 [00:46<00:00, 22.83it/s]


Running Average TRAIN Loss :  132.5856273535526
Running Average TRAIN F1-Score :  0.1672656626906246
Running Average VAL Loss :  131.29761477879114
Running Average VAL F1-Score :  0.1671999583819083


EPOCH  6


100%|██████████| 1056/1056 [00:45<00:00, 23.15it/s]


Running Average TRAIN Loss :  132.09052466623712
Running Average TRAIN F1-Score :  0.17228831706399267
Running Average VAL Loss :  131.0025987625122
Running Average VAL F1-Score :  0.16889282715107715


EPOCH  7


100%|██████████| 1056/1056 [00:45<00:00, 23.02it/s]


Running Average TRAIN Loss :  131.59908506364533
Running Average TRAIN F1-Score :  0.17653172062427708
Running Average VAL Loss :  130.84913390023368
Running Average VAL F1-Score :  0.1739802443023239


EPOCH  8


100%|██████████| 1056/1056 [00:46<00:00, 22.88it/s]


Running Average TRAIN Loss :  131.2212935072003
Running Average TRAIN F1-Score :  0.18082432767771411
Running Average VAL Loss :  130.12302221570695
Running Average VAL F1-Score :  0.17921998990433557


EPOCH  9


100%|██████████| 1056/1056 [00:46<00:00, 22.84it/s]


Running Average TRAIN Loss :  130.7778364239317
Running Average TRAIN F1-Score :  0.18514902827640375
Running Average VAL Loss :  129.99260098593575
Running Average VAL F1-Score :  0.179684135264584


EPOCH  10


100%|██████████| 1056/1056 [00:46<00:00, 22.57it/s]


Running Average TRAIN Loss :  130.39919090270996
Running Average TRAIN F1-Score :  0.18889884694451184
Running Average VAL Loss :  130.4820717402867
Running Average VAL F1-Score :  0.1816998893128974


EPOCH  11


100%|██████████| 1056/1056 [00:45<00:00, 23.20it/s]


Running Average TRAIN Loss :  130.05998239372715
Running Average TRAIN F1-Score :  0.19293708491110892
Running Average VAL Loss :  130.05918734414237
Running Average VAL F1-Score :  0.18423451536468097


EPOCH  12


100%|██████████| 1056/1056 [00:45<00:00, 23.33it/s]


Running Average TRAIN Loss :  128.61242331880513
Running Average TRAIN F1-Score :  0.20184582333560241
Running Average VAL Loss :  129.14695971352714
Running Average VAL F1-Score :  0.1901145726442337


EPOCH  13


100%|██████████| 1056/1056 [00:45<00:00, 23.33it/s]


Running Average TRAIN Loss :  128.37160655946442
Running Average TRAIN F1-Score :  0.2043950503099371
Running Average VAL Loss :  129.09178039005823
Running Average VAL F1-Score :  0.18989254693899835


EPOCH  14


100%|██████████| 1056/1056 [00:45<00:00, 23.28it/s]


Running Average TRAIN Loss :  128.26470320874995
Running Average TRAIN F1-Score :  0.20651188348843294
Running Average VAL Loss :  129.0312761579241
Running Average VAL F1-Score :  0.19207158498466015


EPOCH  15


100%|██████████| 1056/1056 [00:45<00:00, 23.26it/s]


Running Average TRAIN Loss :  128.159089940967
Running Average TRAIN F1-Score :  0.20753209196934194
Running Average VAL Loss :  129.05208574022566
Running Average VAL F1-Score :  0.19280411888446128


EPOCH  16


100%|██████████| 1056/1056 [00:44<00:00, 23.62it/s]


Running Average TRAIN Loss :  128.07091607469502
Running Average TRAIN F1-Score :  0.20841132529136358
Running Average VAL Loss :  128.83642101287842
Running Average VAL F1-Score :  0.19366810470819473


EPOCH  17


100%|██████████| 1056/1056 [00:44<00:00, 23.72it/s]


Running Average TRAIN Loss :  127.99045710852651
Running Average TRAIN F1-Score :  0.209959879856218
Running Average VAL Loss :  129.13444914136613
Running Average VAL F1-Score :  0.19486841612628528


EPOCH  18


100%|██████████| 1056/1056 [00:45<00:00, 23.40it/s]


Running Average TRAIN Loss :  127.91688354810078
Running Average TRAIN F1-Score :  0.21017102538749124
Running Average VAL Loss :  129.09949398040771
Running Average VAL F1-Score :  0.19697332648294313


EPOCH  19


100%|██████████| 1056/1056 [00:44<00:00, 23.60it/s]


Running Average TRAIN Loss :  127.7109381719069
Running Average TRAIN F1-Score :  0.21193268262978757
Running Average VAL Loss :  128.83574540274483
Running Average VAL F1-Score :  0.19396526286644594


EPOCH  20


100%|██████████| 1056/1056 [00:45<00:00, 23.37it/s]


Running Average TRAIN Loss :  127.6863477230072
Running Average TRAIN F1-Score :  0.21245509331029924
Running Average VAL Loss :  128.93520205361503
Running Average VAL F1-Score :  0.19535395982010023


EPOCH  21


100%|██████████| 1056/1056 [00:44<00:00, 23.55it/s]


Running Average TRAIN Loss :  127.65867590181756
Running Average TRAIN F1-Score :  0.21241225983778184
Running Average VAL Loss :  128.6876839229039
Running Average VAL F1-Score :  0.19563207801963603


EPOCH  22


100%|██████████| 1056/1056 [00:44<00:00, 23.54it/s]


Running Average TRAIN Loss :  127.66321536988923
Running Average TRAIN F1-Score :  0.21243242321140837
Running Average VAL Loss :  128.9073532649449
Running Average VAL F1-Score :  0.19624777430934565


EPOCH  23


100%|██████████| 1056/1056 [00:46<00:00, 22.82it/s]


Running Average TRAIN Loss :  127.65119189927073
Running Average TRAIN F1-Score :  0.21269499566970448
Running Average VAL Loss :  128.88768305097307
Running Average VAL F1-Score :  0.19474040583840438


EPOCH  24


100%|██████████| 1056/1056 [00:44<00:00, 23.54it/s]


Running Average TRAIN Loss :  127.65395657943957
Running Average TRAIN F1-Score :  0.21236839286531461
Running Average VAL Loss :  128.89639486585344
Running Average VAL F1-Score :  0.1962287463247776


EPOCH  25


100%|██████████| 1056/1056 [00:44<00:00, 23.57it/s]


Running Average TRAIN Loss :  127.65266032652421
Running Average TRAIN F1-Score :  0.2125067676586861
Running Average VAL Loss :  128.90435246058874
Running Average VAL F1-Score :  0.19544664131743567


TRAINING FINISHED
FINAL TRAINING SCORE :  0.2125067676586861
FINAL VALIDATION SCORE :  0.19544664131743567


In [43]:
def predict():
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
        
    model.eval()
    
    labels = pd.read_csv(paths.train_labels_path, sep = "\t")
    top_terms = labels.groupby("term")["EntryID"].count().sort_values(ascending=False)
    labels_names = top_terms[:config.num_labels].index.values
    print("GENERATE PREDICTION FOR TEST SET...")

    ids_ = np.empty(shape=(len(test_dataloader)*config.num_labels,), dtype=object)
    go_terms_ = np.empty(shape=(len(test_dataloader)*config.num_labels,), dtype=object)
    confs_ = np.empty(shape=(len(test_dataloader)*config.num_labels,), dtype=np.float32)

    for i, (embed, id) in tqdm(enumerate(test_dataloader)):
        embed = embed.to(config.device)
        confs_[i*config.num_labels:(i+1)*config.num_labels] = torch.nn.functional.sigmoid(model(embed)).squeeze().detach().cpu().numpy()
        ids_[i*config.num_labels:(i+1)*config.num_labels] = id[0]
        go_terms_[i*config.num_labels:(i+1)*config.num_labels] = labels_names

    submission_df = pd.DataFrame(data={"Id" : ids_, "GO term" : go_terms_, "Confidence" : confs_})
    print("PREDICTIONS DONE")
    return submission_df

In [44]:
submission_df = predict()

GENERATE PREDICTION FOR TEST SET...


141865it [01:40, 1408.81it/s]


PREDICTIONS DONE


In [45]:
submission_df.head(50)

Unnamed: 0,Id,GO term,Confidence
0,Q9CQV8,GO:0005575,0.876242
1,Q9CQV8,GO:0008150,0.887685
2,Q9CQV8,GO:0110165,0.871613
3,Q9CQV8,GO:0003674,0.881831
4,Q9CQV8,GO:0005622,0.848041
5,Q9CQV8,GO:0009987,0.837071
6,Q9CQV8,GO:0043226,0.808521
7,Q9CQV8,GO:0043229,0.790866
8,Q9CQV8,GO:0005488,0.879563
9,Q9CQV8,GO:0043227,0.787817


In [46]:
len(submission_df)

70932500

In [47]:
submission_df.to_csv('submission.tsv', sep='\t', index=False)