In [9]:
import pandas as pd
from Bio import SeqIO
import re

In [10]:
record_dict = SeqIO.to_dict(SeqIO.parse("../fda/uniprot_sprot.fasta", "fasta"))
kv = list(record_dict.items())
record_dict.clear()
for k, v in kv :
    new_k = (k.split('|')[1])
    record_dict[new_k] = str(v.seq)

In [11]:
df = pd.read_csv('Mutation_perturbation_model.csv')

In [12]:
mut_pattern = r'([A-Z])(\d+)([A-Z])'

wt_seq1s = []
wt_seq2s = []
mut_seq1s = []
mut_seq2s = []
targets = []

for i, row in df.iterrows():

    id1 = row['Mutation UPID']
    id2 = row['Interactor UPID']
    target = row['Y2H_score']

    mut = row['Mutation']
    
    seq1 = record_dict[id1]
    seq2 = record_dict[id2]

    mut_match = re.match(mut_pattern, mut)
    wt_res = mut_match.group(1) 
    res_num = int(mut_match.group(2))
    mut_res = mut_match.group(3)

    if wt_res != seq1[res_num-1]:
        print(f'Failed on index {i}')
        continue
    else:
        seq1_l = list(seq1)
        seq1_l[res_num-1] = mut_res
        mut_seq1 = ''.join(seq1_l)

    if seq1 == mut_seq1:
        print(f'Skipping on index {i}')
        continue

    wt_seq1s.append(seq1)
    wt_seq2s.append(seq2) 
    mut_seq1s.append(mut_seq1)
    mut_seq2s.append(seq2)
    targets.append(target)

Failed on index 34
Skipping on index 166
Skipping on index 391
Failed on index 582
Failed on index 1100
Skipping on index 1293
Failed on index 1409
Skipping on index 1670
Skipping on index 1703
Failed on index 1882
Skipping on index 1898
Skipping on index 2413
Skipping on index 2517
Skipping on index 2999
Skipping on index 3114
Skipping on index 3365


In [13]:
mut_df = pd.DataFrame({'seq1': wt_seq1s, 
                       'seq2': wt_seq2s, 
                       'seq1_mut': mut_seq1s,
                       'seq2_mut': mut_seq2s,
                       'target': targets})

In [14]:
mut_df.to_csv('processed_data_cs.csv')

In [15]:
mut_df['target'].value_counts()

0.0    2573
1.0     833
Name: target, dtype: int64

In [16]:
onco_ppi = pd.read_excel('Table S3.xlsx')

In [17]:
#onco_ppi = onco_ppi[onco_ppi['UniProt_ID_a'] != onco_ppi['UniProt_ID_b']]

In [18]:
mut_pattern = r'([A-Z])(\d+)([A-Z])'

wt_seq1s = []
wt_seq2s = []
mut_seq1s = []
mut_seq2s = []
id1s = []
id2s = []
targets = []

for i, row in onco_ppi.iterrows():
    full_id1 = row['UniProt_ID_a']
    id2 = row['UniProt_ID_b']
    target = row['Growth_score']

    full_id1_split = row['UniProt_ID_a'].split('-')
    
    id1 = full_id1_split[0]
    seq1 = record_dict[id1]
    seq2 = record_dict[id2]

    if len(full_id1_split) > 1:
        mut = full_id1_split[1]
        mut_match = re.match(mut_pattern, mut)
        wt_res = mut_match.group(1) 
        res_num = int(mut_match.group(2))
        mut_res = mut_match.group(3)

        if wt_res != seq1[res_num-1]:
            print(f'Failed on index {i}')
            continue
        else:
            seq1_l = list(seq1)
            seq1_l[res_num-1] = mut_res
            mut_seq1 = ''.join(seq1_l)
    else:
        continue

    if seq1 == mut_seq1:
        print(f'Skipping on index {i}')
        continue

    id1s.append(id1)
    id2s.append(id2)
    wt_seq1s.append(seq1)
    wt_seq2s.append(seq2) 
    mut_seq1s.append(mut_seq1)
    mut_seq2s.append(seq2)
    targets.append(target)

In [19]:
onco_ppi_df = pd.DataFrame({'id1': id1s, 'id2': id2s,
                       'seq1': wt_seq1s, 
                       'seq2': wt_seq2s, 
                       'seq1_mut': mut_seq1s,
                       'seq2_mut': mut_seq2s,
                       'target_og': targets})

In [20]:
onco_ppi_df['target'] = onco_ppi_df['target_og'].apply(lambda x: 1 if x>=3 else 0)

In [21]:
onco_ppi_df.to_csv('processed_data_val_cs.csv')

In [22]:
onco_ppi_df_un = onco_ppi_df[onco_ppi_df['id1'] != onco_ppi_df['id2']]

In [23]:
onco_ppi_df_un.to_csv('processed_data_test_cs.csv')

In [1]:
import torch
import pandas as pd
import numpy as np
from sklearn.preprocessing import PowerTransformer, StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_recall_curve, auc
from torch.utils.data import Dataset
from torch import nn
from tqdm import tqdm
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.neural_network import MLPRegressor, MLPClassifier

In [7]:


test_df = pd.read_csv('processed_data_val_cs.csv')

test_targets = test_df['target'].values.astype(np.int32)
test_targets_og = test_df['target_og'].values.astype(np.int32)

# train, test = convert_train_test_features(train, test)

# model = MLPClassifier()
# param_grid = {
#     "activation": ["relu"],
#     "alpha": [0.0001],
#     "learning_rate": ["adaptive"],
#     "solver": ["adam"],
#     "learning_rate_init": [0.001],
#     "max_iter": [100],
#     "hidden_layer_sizes": [
#         (128), (256,), (512)
#         ],
#     "early_stopping": [True],
#     "random_state": [0],
#     "validation_fraction": [0.3],
#     "tol": [1e-4]}

# cv = KFold(n_splits=5, shuffle=True, random_state=0)

# grid_search = GridSearchCV(estimator=model, 
#                            param_grid=param_grid, 
#                            cv=cv, 
#                            verbose=10, 
#                            scoring='f1')

# grid_search.fit(train, train_targets)

# # Best model found by GridSearchCV
# best_model = grid_search.best_estimator_
    
# # Evaluate the best model on the outer test set
# Y_pred = best_model.predict_proba(test)

In [2]:
test_preds = np.load('best_preds_100_rep.npy')

In [5]:
test_preds[test_preds > 0.5] = 1
test_preds[test_preds <= 0.5] = 0

In [6]:
test_preds_agg = test_preds.mean(axis=0)

In [8]:
test_df['pred'] = test_preds_agg

In [12]:
test_df.to_csv('oncoppi_mutation_results.csv')

In [11]:
test_df[test_df['id1'] != test_df['id2']]

Unnamed: 0.1,Unnamed: 0,id1,id2,seq1,seq2,seq1_mut,seq2_mut,target_og,target,pred
12,12,Q99081,O75575,MNPQQQRMAAIGTDKELSDLLDFSAMFSPPVNSGKTRPTTLGSSQF...,MEVKDANSALLSNYEVFQLLTDLKEQRKESGKNKHSSGQQNLNTIT...,MNPQQQRMAAIGTDKELSDLLDFSAMFSPPVNSGKTRPTTLGSSQF...,MEVKDANSALLSNYEVFQLLTDLKEQRKESGKNKHSSGQQNLNTIT...,0,0,0.3
13,13,Q99081,O75575,MNPQQQRMAAIGTDKELSDLLDFSAMFSPPVNSGKTRPTTLGSSQF...,MEVKDANSALLSNYEVFQLLTDLKEQRKESGKNKHSSGQQNLNTIT...,MNPQQQRMAAIGTDKELSDLLDFSAMFSPPVNSGKTRPTTLGSSQF...,MEVKDANSALLSNYEVFQLLTDLKEQRKESGKNKHSSGQQNLNTIT...,0,0,0.29
14,14,Q13163,Q13164,MLWLALGPFPAMENQVLVIRIKIPNSGAVDWTVHSGPQLLFRDVLD...,MAEPLKEEDGEDGSAEPPGPVKAEPAHTAASVAAKNLALLKARSFD...,MLWLALGPFPAMENQVLVIRIKIPNSGTVDWTVHSGPQLLFRDVLD...,MAEPLKEEDGEDGSAEPPGPVKAEPAHTAASVAAKNLALLKARSFD...,4,1,0.74
15,15,Q13163,Q13164,MLWLALGPFPAMENQVLVIRIKIPNSGAVDWTVHSGPQLLFRDVLD...,MAEPLKEEDGEDGSAEPPGPVKAEPAHTAASVAAKNLALLKARSFD...,MLWLALGPFPAMENQVLVIRIKIPNSGAVDWTVHSGPQLLFRDVLD...,MAEPLKEEDGEDGSAEPPGPVKAEPAHTAASVAAKNLALLKARSFD...,3,1,0.78
16,16,P61224,Q12967,MREYKLVVLGSGGVGKSALTVQFVQGIFVEKYDPTIEDSYRKQVEV...,MVQRMWAEAAGPAGGAEPLFPGSRRSRSVWDAVRLEVGVPDSCPVV...,MREYKLVVLGSGGVGKSALTVQFVQGIFVEKYDPMIEDSYRKQVEV...,MVQRMWAEAAGPAGGAEPLFPGSRRSRSVWDAVRLEVGVPDSCPVV...,0,0,0.16
17,17,P61586,P52565,MAAIRKKLVIVGDGACGKTCLLIVFSKDQFPEVYVPTVFENYVADI...,MAEQEPTAEQLAQIAAENEEDEHSVNYKPPAQKSIQEIQELDKDDE...,MAAIRKKLVIVGDGACGKTCLLIVFSKDQFPEVYVPTVFENYVADI...,MAEQEPTAEQLAQIAAENEEDEHSVNYKPPAQKSIQEIQELDKDDE...,0,0,0.9
18,18,P09917,Q9Y6D9,MPSYTVTVATGSQWFAGTDDYIYLSLVGSAGCSEKHLLDKPFYNDF...,MEDLGENTMVLSTLRSLNNFISQRVEGGSGLDISTSAPGSLQMQYQ...,MPSYTVTVATGSQWFAGTDDYIYLSLVGSAGCSEKHLLDKPFYNDF...,MEDLGENTMVLSTLRSLNNFISQRVEGGSGLDISTSAPGSLQMQYQ...,0,0,0.14
19,19,O75558,Q06455,MKDRLAELLDLSKQYDQQFPDGDDEFDSPHEDIVFETDHILESLYR...,MISVKRNTWRALSLVIGDCRKKGNFEYCQDRTEKHSTMPDSPVDVK...,MKDRLAELLDLSKQYDQQFPDGDDEFDSPHEDIVFETDHILESLYR...,MISVKRNTWRALSLVIGDCRKKGNFEYCQDRTEKHSTMPDSPVDVK...,0,0,0.52
21,21,Q8IX15,Q9UH73,MVRGWEPPPGLDCAISEGHKSEGTMPPNKEASGLSSSPAGLICLPP...,MFGIQESIQRSGSSMKEEPLGSGMNAVRTWMQGAGVLDANTAAQSG...,MVRGWEPPPGLDCAISEGHKSEGTMPPNKEASGLSSSPAGLICLPP...,MFGIQESIQRSGSSMKEEPLGSGMNAVRTWMQGAGVLDANTAAQSG...,0,0,0.0
22,22,P30153,Q9Y534,MAAADGDDSLYPIAVLIDELRNEDVQLRLNSIKKLSTIALALGVER...,MTSESTSPPVVPPLHSPKSPVWPTFPFHREGSRVWERGGVPPRDLP...,MAAADGDDSLYPIAVLIDELRNEDVQLRLNSIKKLSTIALALGVER...,MTSESTSPPVVPPLHSPKSPVWPTFPFHREGSRVWERGGVPPRDLP...,3,1,0.97
