In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve, average_precision_score, precision_recall_curve
from collections import defaultdict

In [2]:
features = pd.read_table('/work3/s220260/PhosKing1.0/data/clean_data/features.tsv').rename(columns={'#UniProt-ID': 'id', 'position': 'pos'})
features_ids = features['id'].unique()
features

  features = pd.read_table('/work3/s220260/PhosKing1.0/data/clean_data/features.tsv').rename(columns={'#UniProt-ID': 'id', 'position': 'pos'})


Unnamed: 0,id,pos,residue,kinases,species,sources
0,Q8RXG3,12,S,,,"EPSD,PhosPhAt"
1,Q8RXG3,6,S,ASK7,,UniProt
2,Q8RXG3,215,T,,,"UniProt,EPSD"
3,Q8RXG3,221,S,ASK7,,"UniProt,EPSD"
4,Q8RXG3,225,T,ASK7,,"UniProt,EPSD"
...,...,...,...,...,...,...
1645019,Z4YM84,371,S,,,EPSD
1645020,Z4YM84,4,T,,,EPSD
1645021,Z4YMB8,768,S,,,EPSD
1645022,Z4YN06,4,T,,,EPSD


In [3]:
features_seq_org = features[['id', "species"]].drop_duplicates()
features_seq_org = [features_seq_org['id']]
features_seq_org

[0          Q8RXG3
 7          Q8L793
 26         Q2V3V8
 28         Q8GWB3
 31         Q9C505
             ...  
 1645016    Z4YLR9
 1645017    Z4YM84
 1645021    Z4YMB8
 1645022    Z4YN06
 1645023    Z4YN62
 Name: id, Length: 229088, dtype: object]

In [4]:
seq_organisms = pd.read_csv('/work3/s220260/PhosKing1.0/data/clean_data/seq_organisms.csv', header=None).rename(columns={0: 'id', 1: 'organism', 2: 'phylogeny'}).drop_duplicates()
seq_organisms

Unnamed: 0,id,organism,phylogeny
0,D4A7L7,Rattus norvegicus,Eukaryota Metazoa Chordata Craniata Vertebrata...
1,P70365,Mus musculus,Eukaryota Metazoa Chordata Craniata Vertebrata...
2,P55096,Mus musculus,Eukaryota Metazoa Chordata Craniata Vertebrata...
3,A0A0G2K2P5,Rattus norvegicus,Eukaryota Metazoa Chordata Craniata Vertebrata...
4,P61968,Homo sapiens,Eukaryota Metazoa Chordata Craniata Vertebrata...
...,...,...,...
205776,P53567,Homo sapiens,Eukaryota Metazoa Chordata Craniata Vertebrata...
205777,Q8VG33,Mus musculus,Eukaryota Metazoa Chordata Craniata Vertebrata...
205778,Q9LUL2,Arabidopsis thaliana,Eukaryota Viridiplantae Streptophyta Embryophy...
205779,Q9BQ48,Homo sapiens,Eukaryota Metazoa Chordata Craniata Vertebrata...


In [5]:
def group_organism(row: pd.Series):
    organism = row['organism']
    phylogeny = row['phylogeny']
    conditions = {
        'human': 'homo' in organism.lower() or 'homo' in phylogeny.lower(),
        'mammalian': 'mammalia' in phylogeny.lower() or 'mus musculus' in organism.lower() or 'rattus' in organism.lower() or 'sus scrofa' in organism.lower(),
        'other_animal': 'metazoa' in phylogeny.lower(),
        'bacteria': 'bacteria' in phylogeny.lower(),
        'fungal': 'fungi' in phylogeny.lower(),
        'plant': 'plantae' in phylogeny.lower(),
        'other_eukaryotes': 'amoebozoa' in phylogeny.lower() or 'sar' in phylogeny.lower() or 'discoba' in phylogeny.lower() or 'cryptophyceae' in phylogeny.lower() or 'rhodophyta' in phylogeny.lower() or 'haptista' in phylogeny.lower() or 'giardia' in phylogeny.lower() or 'cyanophora' in phylogeny.lower(),
        'virus': 'virus' in organism.lower() or 'virus' in phylogeny.lower(),
        'archaea': 'archaea' in phylogeny.lower(),
    }
    priority = ['human', 'mammalian', 'other_animal', 'bacteria', 'fungal', 'plant', 'other_eukaryotes', 'virus', 'archaea']
    for org_group in priority:
        if conditions[org_group]:
            return org_group
    return None
    

In [6]:
seq_organisms['group'] = seq_organisms.apply(group_organism, axis=1)
seq_organisms

Unnamed: 0,id,organism,phylogeny,group
0,D4A7L7,Rattus norvegicus,Eukaryota Metazoa Chordata Craniata Vertebrata...,mammalian
1,P70365,Mus musculus,Eukaryota Metazoa Chordata Craniata Vertebrata...,mammalian
2,P55096,Mus musculus,Eukaryota Metazoa Chordata Craniata Vertebrata...,mammalian
3,A0A0G2K2P5,Rattus norvegicus,Eukaryota Metazoa Chordata Craniata Vertebrata...,mammalian
4,P61968,Homo sapiens,Eukaryota Metazoa Chordata Craniata Vertebrata...,human
...,...,...,...,...
205776,P53567,Homo sapiens,Eukaryota Metazoa Chordata Craniata Vertebrata...,human
205777,Q8VG33,Mus musculus,Eukaryota Metazoa Chordata Craniata Vertebrata...,mammalian
205778,Q9LUL2,Arabidopsis thaliana,Eukaryota Viridiplantae Streptophyta Embryophy...,plant
205779,Q9BQ48,Homo sapiens,Eukaryota Metazoa Chordata Craniata Vertebrata...,human


In [7]:
missing_group = seq_organisms[seq_organisms['group'].isna()]
print('Unclassified organisms:')
print(*missing_group['phylogeny'].unique().tolist(), sep='\n')
print('\nUnclassified phylogenies:')
print(*missing_group['organism'].unique().tolist(), sep='\n')

Unclassified organisms:
unclassified sequences metagenomes ecological metagenomes
unclassified sequences environmental samples
unclassified sequences metagenomes

Unclassified phylogenies:
bioreactor metagenome
mine drainage metagenome
uncultured organism
uncultured organism Bio4
metagenome


In [8]:
  
true_phosphos = defaultdict(set)
true_phosphos.update(features.groupby('id')['pos'].apply(set))

def get_truth(row: pd.Series):
    return int(row['pos'] in true_phosphos[row['id']])


In [9]:
netphos4_test = pd.read_table('pre_model1_test_.tsv', header=None).rename(columns={0: 'id', 1: 'pos', 2: '?', 3: 'score'}).drop('?', axis=1)
assert len(netphos4_test) == len(netphos4_test[netphos4_test['id'].isin(features_ids)]), 'Mismatch between features and test set'
netphos4_test['y_true'] = netphos4_test.apply(get_truth, axis=1)
assert len(netphos4_test[netphos4_test['y_true'].isna()]) == 0
print(f'AUC: {roc_auc_score(netphos4_test["y_true"], netphos4_test["score"]):.3f}')
netphos4_test

AUC: 0.865


Unnamed: 0,id,pos,score,y_true
0,Q5PPN7,2,0.472,0
1,Q5PPN7,5,0.551,0
2,Q5PPN7,9,0.474,0
3,Q5PPN7,17,0.433,0
4,Q5PPN7,24,0.395,0
...,...,...,...,...
205410,Q58FZ9,195,0.232,0
205411,Q58FZ9,196,0.247,0
205412,Q58FZ9,204,0.128,0
205413,Q58FZ9,206,0.105,0


In [10]:
musitedeep_test = pd.read_table('Test_MusiteDeep_1000.tsv', header=None).rename(columns={0: 'id', 1: 'pos', 2: '?', 3: 'score'}).drop('?', axis=1)
assert len(musitedeep_test) == len(musitedeep_test[musitedeep_test['id'].isin(features_ids)]), 'Mismatch between features and test set'
musitedeep_test['y_true'] = musitedeep_test.apply(get_truth, axis=1)
assert len(musitedeep_test[musitedeep_test['y_true'].isna()]) == 0
print(f'AUC: {roc_auc_score(musitedeep_test["y_true"], musitedeep_test["score"]):.3f}')
musitedeep_test

AUC: 0.759


Unnamed: 0,id,pos,score,y_true
0,P73120,12,0.240,0
1,P73120,14,0.530,0
2,P73120,20,0.264,0
3,P73120,25,0.138,0
4,P73120,28,0.194,0
...,...,...,...,...
98456,P0DPA2,353,0.503,0
98457,P0DPA2,355,0.654,0
98458,P0DPA2,377,0.064,0
98459,P0DPA2,387,0.667,1


In [11]:
netphos3_test = pd.read_table('Test_NetPhos3.1_1000.tsv', header=None).rename(columns={0: 'id', 1: 'pos', 2: '?', 3: 'score'}).drop('?', axis=1)
assert len(netphos3_test) == len(netphos3_test[netphos3_test['id'].isin(features_ids)]), 'Mismatch between features and test set'
netphos3_test['y_true'] = netphos3_test.apply(get_truth, axis=1)
assert len(netphos3_test[netphos3_test['y_true'].isna()]) == 0
print(f'AUC: {roc_auc_score(netphos3_test["y_true"], netphos3_test["score"]):.3f}')
netphos3_test

AUC: 0.626


Unnamed: 0,id,pos,score,y_true
0,A0A088AGW7,3,0.440,0
1,A0A088AGW7,4,0.512,0
2,A0A088AGW7,13,0.443,0
3,A0A088AGW7,22,0.516,0
4,A0A088AGW7,26,0.443,0
...,...,...,...,...
107464,V4Z445,2732,0.816,0
107465,V4Z445,2743,0.660,0
107466,V4Z445,2753,0.544,0
107467,V4Z445,2759,0.439,0


In [12]:
netphospan_test = pd.read_table('Test_NetPhospan1.0.tsv', header=None).rename(columns={0: 'id', 1: 'pos', 2: '?', 3: 'score'}).drop('?', axis=1)
netphospan_test['pos'] += 10
netphospan_test = netphospan_test[netphospan_test['id'].isin(features_ids)]  # wtf
assert len(netphospan_test) == len(netphospan_test[netphospan_test['id'].isin(features_ids)]), 'Mismatch between features and test set'
netphospan_test['y_true'] = netphospan_test.apply(get_truth, axis=1)
assert len(netphospan_test[netphospan_test['y_true'].isna()]) == 0
print(f'AUC: {roc_auc_score(netphospan_test["y_true"], netphospan_test["score"]):.3f}')
netphospan_test

AUC: 0.740


Unnamed: 0,id,pos,score,y_true
0,W4XZI0,13,0.772,1
1,W4XZI0,21,0.124,0
2,W4XZI0,24,0.121,0
3,W4XZI0,27,0.257,0
4,W4XZI0,35,0.317,0
...,...,...,...,...
35454,Q9GU50,464,0.224,0
35455,Q9GU50,465,0.203,0
35456,Q9GU50,467,0.172,0
35457,Q9GU50,469,0.097,0


In [13]:
predictors = {
    'NetPhos 4.0': netphos4_test,
    'MusiteDeep': musitedeep_test,
    'NetPhospan': netphospan_test,
    'NetPhos 3.1': netphos3_test,
}
predictor_aucs = {}
predictor_auprcs = {}
for predictor, df in predictors.items():
    # ROC and AUC
    auc = roc_auc_score(df["y_true"], df["score"])
    fpr, tpr, thresholds = roc_curve(df["y_true"], df["score"])
    roc_curve_data = pd.DataFrame({'FPR': fpr, 'TPR': tpr, 'threshold': thresholds})
    roc_curve_data.to_csv(f'roc_curves/ROC_curve_{predictor.lower().replace(" ", "_")}_global.csv', sep=',', header=True, index=False)
    predictor_aucs[predictor] = auc
    
    # PRC and AUPRC
    auprc = average_precision_score(df["y_true"], df["score"])
    precision, recall, thresholds = precision_recall_curve(df["y_true"], df["score"])
    precision, recall = precision[:-1], recall[:-1]
    prc_curve_data = pd.DataFrame({'precision': precision, 'recall': recall, 'threshold': thresholds})
    prc_curve_data.to_csv(f'roc_curves/PRC_curve_{predictor.lower().replace(" ", "_")}_global.csv', sep=',', header=True, index=False)
    predictor_auprcs[predictor] = auprc

aucs = pd.DataFrame({'auc': predictor_aucs})
aucs.to_csv('roc_curves/AUCs_global.csv', header=True, index=True)

auprcs = pd.DataFrame({'auprc': predictor_auprcs})
auprcs.to_csv('roc_curves/AUPRCs_global.csv', header=True, index=True)
pd.merge(aucs, auprcs, left_index=True, right_index=True)

Unnamed: 0,auc,auprc
MusiteDeep,0.758658,0.195745
NetPhos 3.1,0.625522,0.086267
NetPhos 4.0,0.8646,0.303167
NetPhospan,0.739899,0.140581


**Make inner joins for organism group**

In [14]:
netphos4_test_with_group = netphos4_test.merge(seq_organisms.drop(['phylogeny', 'organism'], axis=1), on='id', how='inner')
print(f'Went from {len(netphos4_test)} phosphos to {len(netphos4_test_with_group)} phosphos')
netphos4_test_with_group

Went from 205415 phosphos to 189180 phosphos


Unnamed: 0,id,pos,score,y_true,group
0,Q5PPN7,2,0.472,0,mammalian
1,Q5PPN7,5,0.551,0,mammalian
2,Q5PPN7,9,0.474,0,mammalian
3,Q5PPN7,17,0.433,0,mammalian
4,Q5PPN7,24,0.395,0,mammalian
...,...,...,...,...,...
189175,Q58FZ9,195,0.232,0,plant
189176,Q58FZ9,196,0.247,0,plant
189177,Q58FZ9,204,0.128,0,plant
189178,Q58FZ9,206,0.105,0,plant


In [15]:
musitedeep_test_with_group = musitedeep_test.merge(seq_organisms.drop(['phylogeny', 'organism'], axis=1), on='id', how='inner')
print(f'Went from {len(musitedeep_test)} phosphos to {len(musitedeep_test_with_group)} phosphos')
musitedeep_test_with_group

Went from 98461 phosphos to 91830 phosphos


Unnamed: 0,id,pos,score,y_true,group
0,P73120,12,0.240,0,bacteria
1,P73120,14,0.530,0,bacteria
2,P73120,20,0.264,0,bacteria
3,P73120,25,0.138,0,bacteria
4,P73120,28,0.194,0,bacteria
...,...,...,...,...,...
91825,P0DPA2,353,0.503,0,human
91826,P0DPA2,355,0.654,0,human
91827,P0DPA2,377,0.064,0,human
91828,P0DPA2,387,0.667,1,human


In [16]:
netphospan_test_with_group = netphospan_test.merge(seq_organisms.drop(['phylogeny', 'organism'], axis=1), on='id', how='inner')
print(f'Went from {len(netphospan_test)} phosphos to {len(netphospan_test_with_group)} phosphos')
netphospan_test_with_group

Went from 35459 phosphos to 30661 phosphos


Unnamed: 0,id,pos,score,y_true,group
0,A0A0H3MAG4,45,0.063,0,bacteria
1,A0A0H3MAG4,49,0.102,0,bacteria
2,A0A0H3MAG4,53,0.117,0,bacteria
3,A0A0H3MAG4,57,0.162,0,bacteria
4,A0A0H3MAG4,63,0.095,0,bacteria
...,...,...,...,...,...
30656,Q9GU50,464,0.224,0,other_animal
30657,Q9GU50,465,0.203,0,other_animal
30658,Q9GU50,467,0.172,0,other_animal
30659,Q9GU50,469,0.097,0,other_animal


In [17]:
netphos3_test_with_group = netphos3_test.merge(seq_organisms.drop(['phylogeny', 'organism'], axis=1), on='id', how='inner')
print(f'Went from {len(netphos3_test)} phosphos to {len(netphos3_test_with_group)} phosphos')
netphos3_test_with_group

Went from 107469 phosphos to 99678 phosphos


Unnamed: 0,id,pos,score,y_true,group
0,A8IQA2,5,0.399,0,plant
1,A8IQA2,6,0.661,0,plant
2,A8IQA2,11,0.918,0,plant
3,A8IQA2,22,0.449,0,plant
4,A8IQA2,25,0.668,0,plant
...,...,...,...,...,...
99673,V4Z445,2732,0.816,0,other_eukaryotes
99674,V4Z445,2743,0.660,0,other_eukaryotes
99675,V4Z445,2753,0.544,0,other_eukaryotes
99676,V4Z445,2759,0.439,0,other_eukaryotes


In [18]:
groups = ['human', 'mammalian', 'mammalian_w_human', 'other_animal', 'bacteria', 'fungal', 'plant', 'other_eukaryotes', 'virus', 'archaea']
predictors = {
    'NetPhos 4.0': netphos4_test_with_group,
    'MusiteDeep': musitedeep_test_with_group,
    'NetPhospan': netphospan_test_with_group,
    'NetPhos 3.1': netphos3_test_with_group,
}
predictor_group_aucs = {}
for predictor, df in predictors.items():
    group_aucs = {}
    for group in groups:
        if group != 'mammalian_w_human':
            group_subset = df[df['group'] == group]
        else:
            group_subset = df[df['group'].isin(('human', 'mammalian'))]
            
        if len(group_subset) > 0:
            auc = roc_auc_score(group_subset["y_true"], group_subset["score"])
            fpr, tpr, thresholds = roc_curve(group_subset["y_true"], group_subset["score"])
            roc_curve_data = pd.DataFrame({'FPR': fpr, 'TPR': tpr, 'threshold': thresholds})
            roc_curve_data.to_csv(f'roc_curves/ROC_curve_{predictor.lower().replace(" ", "_")}_{group}.csv', header=True, index=False)
        else:
            auc = 0
        group_aucs[group] = round(auc, 2)
    predictor_group_aucs[predictor] = group_aucs

aucs = pd.DataFrame(predictor_group_aucs).transpose()
aucs.to_csv('roc_curves/AUCs_group.csv', header=True, index=True)
aucs


Unnamed: 0,human,mammalian,mammalian_w_human,other_animal,bacteria,fungal,plant,other_eukaryotes,virus,archaea
NetPhos 4.0,0.75,0.81,0.77,0.88,0.84,0.86,0.83,0.93,0.73,0.8
MusiteDeep,0.68,0.78,0.72,0.82,0.65,0.79,0.78,0.78,0.95,0.82
NetPhospan,0.67,0.77,0.73,0.73,0.71,0.77,0.74,0.77,0.95,0.83
NetPhos 3.1,0.61,0.62,0.61,0.66,0.6,0.64,0.64,0.61,0.51,0.51
