In [1]:
from skempi_utils import *
from scipy.stats import pearsonr

In [2]:
df = skempi_df
# Keep only the rows whose Protein identifier is at most 8 characters long
# (the mask marks identifiers longer than 8 characters and is then negated).
df_multi = df[~np.asarray([len(s) > 8 for s in df.Protein])]
# The first four characters of each identifier are used as the comparison key.
s_multi = {s[:4] for s in df_multi.Protein}
s_groups = {s[:4] for s in G1 + G2 + G3 + G4 + G5}
# A bare expression is only displayed when it is the last statement of a cell,
# so the overlap summary has to be printed explicitly to be visible.
print((len(s_multi & s_groups), len(s_multi), len(s_groups)))
df_multi.head()

Unnamed: 0,Protein,Mutation(s)_PDB,Mutation(s)_cleaned,Location(s),Hold_out_type,Hold_out_proteins,Affinity_mut (M),Affinity_wt (M),DDG,Reference,...,Temperature,kon_mut (M^(-1)s^(-1)),kon_wt (M^(-1)s^(-1)),koff_mut (s^(-1)),koff_wt (s^(-1)),dH_mut (kcal mol^(-1)),dH_wt (kcal mol^(-1)),dS_mut (cal mol^(-1) K^(-1)),dS_wt (cal mol^(-1) K^(-1)),Notes
0,1CSE_E_I,LI45G,LI38G,COR,PI,PI,5.26e-11,1.12e-12,2.280577,9048543,...,294,,,,,,,,,
1,1CSE_E_I,LI45S,LI38S,COR,PI,PI,8.33e-12,1.12e-12,1.188776,9048543,...,294,,,,,,,,,
2,1CSE_E_I,LI45P,LI38P,COR,PI,PI,1.02e-07,1.12e-12,6.765446,9048543,...,294,,,,,,,,,
3,1CSE_E_I,LI45I,LI38I,COR,PI,PI,1.72e-10,1.12e-12,2.982502,9048543,...,294,,,,,,,,,
4,1CSE_E_I,LI45D,LI38D,COR,PI,PI,1.92e-09,1.12e-12,4.411843,9048543,...,294,,,,,,,,,


In [3]:
from sklearn.preprocessing import StandardScaler
from itertools import combinations as comb
from sklearn.externals import joblib
import numpy as np

def evaluate(group_str, y_true, y_pred, ix):
    """Print and return Pearson correlations between targets and predictions.

    Three correlations are computed: over all samples, over the samples whose
    flag in ``ix`` equals 0 ("pos") and over those whose flag equals 1 ("neg").

    Args:
        group_str: label used as the prefix of the printed summary line.
        y_true: array-like of ground-truth values.
        y_pred: array-like of predicted values, aligned with ``y_true``.
        ix: array-like of per-sample flags (0 or 1), aligned with ``y_true``.

    Returns:
        Tuple ``(cor_all, cor_pos, cor_neg)``. An entry is NaN when the
        corresponding subset holds fewer than two samples.
    """
    # Boolean-mask indexing needs ndarrays; coercing lets plain lists work too.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    ix = np.asarray(ix)

    def _pearson(a, b):
        # pearsonr rejects inputs shorter than two samples; report NaN instead
        # of aborting the whole evaluation (e.g. when no "neg" samples exist).
        if len(a) < 2:
            return float('nan')
        r, _ = pearsonr(a, b)
        return r

    is_pos = ix == 0
    is_neg = ix == 1
    cor_all = _pearson(y_true, y_pred)
    cor_pos = _pearson(y_true[is_pos], y_pred[is_pos])
    cor_neg = _pearson(y_true[is_neg], y_pred[is_neg])
    print("[%s:%d] cor_all:%.3f, cor_pos:%.3f, cor_neg:%.3f" % (group_str, len(y_true), cor_all, cor_pos, cor_neg))
    return cor_all, cor_pos, cor_neg

def run_cv_test(X, y, ix, get_regressor, modelname, normalize=1):
    """Cross-validate a regressor over every pair of held-out groups.

    For each unordered pair of the NUM_GROUPS groups, the rows whose
    ``ix[:, 0]`` equals either (1-based) group number form the test set and all
    remaining rows form the training set. A fresh regressor is fitted on the
    training rows, saved to ``models/<modelname><fold>.pkl`` and scored on the
    test rows with ``evaluate``.

    Args:
        X: 2-D feature array, one row per sample.
        y: 1-D target array aligned with ``X``.
        ix: 2-D array aligned with ``X``; column 0 is compared against the
            group numbers and column 1 is forwarded to ``evaluate`` as the
            per-sample flag.
        get_regressor: zero-argument callable returning an unfitted estimator
            that provides ``fit`` and ``predict``.
        modelname: prefix of the file each fitted model is written to.
        normalize: when equal to 1, features are standardized with a scaler
            fitted on the training rows only.

    Returns:
        List of four arrays: the concatenated ground truth, the concatenated
        predictions, the concatenated ``ix[:, 1]`` flags of the test rows, and
        the per-fold ``[cor_all, cor_pos, cor_neg]`` triples.
    """
    import os

    # joblib.dump does not create missing directories; make sure the target
    # exists so the run does not fail after the first (possibly slow) fit.
    if not os.path.isdir('models'):
        os.makedirs('models')

    gt, preds, indx, cors = [], [], [], []
    for i, pair in enumerate(comb(range(NUM_GROUPS), 2)):
        # Group numbers stored in ix[:, 0] are compared 1-based.
        g1, g2 = np.asarray(pair) + 1
        indx_tst = (ix[:, 0] == g1) | (ix[:, 0] == g2)
        indx_trn = np.logical_not(indx_tst)
        y_trn = y[indx_trn]
        y_true = y[indx_tst]
        X_trn = X[indx_trn]
        X_tst = X[indx_tst]
        if normalize == 1:
            # Fit the scaler on training rows only so that no statistics of
            # the held-out rows leak into the model.
            scaler = StandardScaler()
            scaler.fit(X_trn)
            X_trn = scaler.transform(X_trn)
            X_tst = scaler.transform(X_tst)
        regressor = get_regressor()
        regressor.fit(X_trn, y_trn)
        # Persist the fitted model for later reuse; the in-memory estimator is
        # used directly for prediction, so no reload is needed here.
        joblib.dump(regressor, 'models/%s%s.pkl' % (modelname, i))
        y_pred = regressor.predict(X_tst)
        cor, pos, neg = evaluate("G%d,G%d" % (g1, g2), y_true, y_pred, ix[indx_tst, 1])
        cors.append([cor, pos, neg])
        indx.extend(ix[indx_tst, 1])
        preds.extend(y_pred)
        gt.extend(y_true)
    return [np.asarray(a) for a in [gt, preds, indx, cors]]

def run_cv_test_ensemble(X, y, ix, alpha=0.5, normalize=1):
    """Score a weighted blend of the saved SVR and RFR models on every fold.

    The folds are the same group pairs, in the same order, as in
    ``run_cv_test``. For fold ``i`` the models are loaded from
    ``models/svr<i>.pkl`` and ``models/rfr<i>.pkl``, so those files must
    already exist.

    Args:
        X: 2-D feature array, one row per sample.
        y: 1-D target array aligned with ``X``.
        ix: 2-D array aligned with ``X``; column 0 is compared against the
            group numbers and column 1 is forwarded to ``evaluate`` as the
            per-sample flag.
        alpha: weight of the SVR prediction; the RFR prediction is weighted by
            ``1 - alpha``.
        normalize: when equal to 1, test features are standardized with a
            scaler fitted on the fold's training rows.

    Returns:
        List of four arrays: the concatenated ground truth, the concatenated
        blended predictions, the concatenated ``ix[:, 1]`` flags of the test
        rows, and the per-fold ``[cor_all, cor_pos, cor_neg]`` triples.
    """
    gt, preds, indx, cors = [], [], [], []
    for i, pair in enumerate(comb(range(NUM_GROUPS), 2)):
        # Group numbers stored in ix[:, 0] are compared 1-based.
        g1, g2 = np.asarray(pair) + 1
        indx_tst = (ix[:, 0] == g1) | (ix[:, 0] == g2)
        indx_trn = np.logical_not(indx_tst)
        y_true = y[indx_tst]
        X_tst = X[indx_tst]
        # NOTE(review): joblib.load unpickles arbitrary objects; only load
        # model files that were produced locally.
        svr = joblib.load('models/svr%d.pkl' % i)
        rfr = joblib.load('models/rfr%d.pkl' % i)
        if normalize == 1:
            # Only the models are persisted, not the scaler, so it is re-fitted
            # on this fold's training rows before transforming the test rows.
            scaler = StandardScaler()
            scaler.fit(X[indx_trn])
            X_tst = scaler.transform(X_tst)
        y_pred_svr = svr.predict(X_tst)
        y_pred_rfr = rfr.predict(X_tst)
        y_pred = alpha * y_pred_svr + (1 - alpha) * y_pred_rfr
        cor, pos, neg = evaluate("G%d,G%d" % (g1, g2), y_true, y_pred, ix[indx_tst, 1])
        cors.append([cor, pos, neg])
        indx.extend(ix[indx_tst, 1])
        preds.extend(y_pred)
        gt.extend(y_true)
    return [np.asarray(a) for a in [gt, preds, indx, cors]]

def records_to_xy(skempi_records, load_neg=True):
    """Turn records into feature, target and index arrays.

    Every record contributes one row made of ``features(True)``, ``[ddg]`` and
    ``[group, is_minus]``. Unless ``load_neg`` is falsy, the row of
    ``reversed(record)`` is appended directly after the row of the record.

    Args:
        skempi_records: iterable of records; each one (and its reversed
            counterpart, when used) must have a non-None ``struct``.
        load_neg: whether to also emit a row for the reversed record.

    Returns:
        Tuple ``(X, y, ix)`` of arrays built column-wise from the rows.
    """
    rows = []

    def _append_row(rec):
        # A record without a structure cannot be featurized.
        assert rec.struct is not None
        rows.append([rec.features(True), [rec.ddg], [rec.group, rec.is_minus]])

    for record in tqdm(skempi_records, desc="records processed"):
        _append_row(record)
        if load_neg:
            _append_row(reversed(record))

    # Transpose the list of rows into three columns, one array per column.
    X, y, ix = (np.asarray(column) for column in zip(*rows))
    return X, y, ix

In [4]:
def get_temperature_array(records, agg=np.min):
    """Aggregate per-residue mean atom ``temp`` values for every skempi_df row.

    For each row of the module-level ``skempi_df``, every mutation listed in
    the comma-separated "Mutation(s)_cleaned" column is mapped to its residue
    in the matching entry of ``records``; the mean of the ``temp`` attribute
    over that residue's atoms is taken, and the per-mutation means of the row
    are reduced with ``agg``.

    Args:
        records: mapping keyed by the tuple obtained from splitting the row's
            ``Protein`` string on '_'; each value is indexed by chain id and
            then by residue index.
        agg: callable reducing the list of per-mutation means to one value.

    Returns:
        List with one aggregated value per row of ``skempi_df``.

    NOTE(review): ``temp`` is presumably the crystallographic temperature
    factor (B-factor) of the atom -- confirm against the structure class.
    """
    arr = []
    # Wrapping the iterator lets tqdm manage the progress bar's lifetime
    # itself instead of relying on manual update()/close() calls.
    rows = tqdm(skempi_df.iterrows(), total=len(skempi_df), desc="row processed")
    for _, row in rows:
        # The structure depends only on the row, so look it up once per row
        # rather than once per mutation.
        struct = records[tuple(row.Protein.split('_'))]
        arr_obs_mut = []
        for mutation in row["Mutation(s)_cleaned"].split(','):
            mut = Mutation(mutation)
            res = struct[mut.chain_id][mut.i]
            arr_obs_mut.append(np.mean([a.temp for a in res.atoms]))
        arr.append(agg(arr_obs_mut))
    return arr

# Load the structures found under ../data/pdbs_n (distance matrices are not
# computed) and derive one aggregated atom-`temp` value per skempi_df row,
# using the minimum over the row's mutations.
# NOTE(review): `skempi_records` holds the result of load_skempi_structs here
# and is rebound to the result of load_skempi_records in the next cell.
skempi_records = load_skempi_structs(pdb_path="../data/pdbs_n", compute_dist_mat=False)
temp_arr = get_temperature_array(skempi_records, agg=np.min)

skempi structures processed: 100%|██████████| 158/158 [00:26<00:00,  5.89it/s]
row processed: 100%|██████████| 3047/3047 [00:00<00:00, 3520.53it/s]


In [5]:
# Load the structures found under ../data/pdbs (distance matrices are not
# computed) and build the records from them.
# NOTE(review): this rebinds `skempi_records`, which the previous cell bound to
# the result of load_skempi_structs for ../data/pdbs_n.
skempi_structs = load_skempi_structs("../data/pdbs", compute_dist_mat=False)
skempi_records = load_skempi_records(skempi_structs)

skempi structures processed: 100%|██████████| 158/158 [02:09<00:00,  1.22it/s]
skempi records processed: 100%|██████████| 3047/3047 [00:00<00:00, 3783.58it/s]


In [6]:
# X_pos, y_pos, ix_pos = records_to_xy(skempi_records)
# X_pos.shape, y_pos.shape, ix_pos.shape

In [7]:
# Build features, targets and (group, is_minus) indices for every record and,
# since load_neg defaults to True, for its reversed counterpart as well.
# NOTE(review): the recorded run of this cell took about 4h53m; consider
# caching X_, y_ and ix_ on disk so that a full re-run stays feasible.
X_, y_, ix_ = records_to_xy(skempi_records)

records processed: 100%|██████████| 3047/3047 [4:52:51<00:00,  5.77s/it]


In [9]:
# Use every feature column, the single target column flattened to 1-D, and the
# index matrix unchanged.
X, y, ix = X_[:, :], y_[:, 0], ix_
# X = np.concatenate([X.T, [temp_arr]], axis=0).T
tuple(a.shape for a in (X, y, ix))

((6094, 9), (6094,), (6094, 2))

In [10]:
# Cross-validate an RBF-kernel SVR; the fitted fold models are saved with the
# 'svr' prefix so that the ensemble step below can load them.
print("----->SVR")
from sklearn.svm import SVR
def get_regressor(): return SVR(kernel='rbf')
gt, preds, indx, cors = run_cv_test(X, y, ix, get_regressor, 'svr', normalize=1)
# "CAT" scores the predictions of all folds concatenated together.
cor1, _, _ = evaluate("CAT", gt, preds, indx)
print(np.mean(cors, axis=0))

# Same protocol for a random forest; models are saved with the 'rfr' prefix.
print("----->RFR")
from sklearn.ensemble import RandomForestRegressor
def get_regressor(): return RandomForestRegressor(n_estimators=50, random_state=0)
gt, preds, indx, cors = run_cv_test(X, y, ix, get_regressor, 'rfr', normalize=1)
cor2, _, _ = evaluate("CAT", gt, preds, indx)
print(np.mean(cors, axis=0))

# alpha = cor1/(cor1+cor2)
alpha = 0.5
print("----->%.2f*SVR + %.2f*RFR" % (alpha, 1-alpha))
# Pass alpha explicitly so that the blend actually evaluated always matches the
# weights printed above, whatever value alpha is set to.
gt, preds, indx, cors = run_cv_test_ensemble(X, y, ix, alpha=alpha, normalize=1)
cor, _, _ = evaluate("CAT", gt, preds, indx)
print(np.mean(cors, axis=0))

----->SVR
[G1,G2:1468] cor_all:0.572, cor_pos:0.371, cor_neg:0.385
[G1,G3:1580] cor_all:0.481, cor_pos:0.318, cor_neg:0.362
[G1,G4:1630] cor_all:0.438, cor_pos:0.261, cor_neg:0.314
[G1,G5:1820] cor_all:0.559, cor_pos:0.416, cor_neg:0.429
[G2,G3:1468] cor_all:0.594, cor_pos:0.359, cor_neg:0.485
[G2,G4:1518] cor_all:0.542, cor_pos:0.330, cor_neg:0.394
[G2,G5:1708] cor_all:0.631, cor_pos:0.472, cor_neg:0.520
[G3,G4:1630] cor_all:0.481, cor_pos:0.295, cor_neg:0.371
[G3,G5:1820] cor_all:0.614, cor_pos:0.474, cor_neg:0.517
[G4,G5:1870] cor_all:0.578, cor_pos:0.426, cor_neg:0.476
[CAT:16512] cor_all:0.541, cor_pos:0.365, cor_neg:0.414
[ 0.54895502  0.37220431  0.42546674]
----->RFR
[G1,G2:1468] cor_all:0.606, cor_pos:0.480, cor_neg:0.328
[G1,G3:1580] cor_all:0.644, cor_pos:0.578, cor_neg:0.443
[G1,G4:1630] cor_all:0.548, cor_pos:0.457, cor_neg:0.357
[G1,G5:1820] cor_all:0.642, cor_pos:0.584, cor_neg:0.474
[G2,G3:1468] cor_all:0.654, cor_pos:0.462, cor_neg:0.488
[G2,G4:1518] cor_all:0.560, cor