In [81]:
import numpy as np
import pandas as pd
import difflib
import Levenshtein
import xgboost as xgb
import sklearn
from sklearn.model_selection import train_test_split
import scipy.stats

In [57]:
#### This function find the wildtype of a group of enzymes by 
#### taking the most frequent amino acids as the wildtype.
#### Return: the sequence of wildtype with the same formula 
#### as the sequence in training_data.
def get_wt(X):
    length = {}
    for i in X:
        l = len(i)
        if l not in length:
            length[l] = 1
        else:
            length[l] += 1
    wt_len = max(length, key=length.get)
    wild_type = []
    for i in range(wt_len):
        lt = []
        for j in X:
            if (len(j)>i):
                lt.append(j[i])
        wild_type.append(get_most(lt))
    return ''.join(wild_type)
#### This function finds the wildtype based on the sample have the 
#### smallest sum of Levenshtein distance on its amino acid sequence
#### It returns the index of the wildtype in the group.
def get_wt_sup(X):
    l = len(X)
    lenlist = []
    for i in range(l):
        s = 0
        for j in range(l):
            s += Levenshtein.distance(X[i],X[j])
        lenlist.append(s)
    lenlist = np.array(lenlist)
    return np.argmin(lenlist)

#### This function returns the key with maximum value in a dictionary.
def get_most(X):
    c_max = 0
    n = X[0]
    for i in X:
        c = X.count(i)
        if c > c_max:
            n = i
            c_max = c
    return n


In [3]:
#### We use difflib.ndiff from package difflib to
#### find the difference between two strings. 
#### These two functions will find the mutation type 
#### from the result originated by difflib.ndiff
#### And transform it to a more readable form
#### Returns: a dictionary specifies the mutation type
#### and its amount in a sequence.
def find_mutation(x,y):
    if (len(x)!=len(y)):
        d = find_mut(x,y)
        if (len(d)==1):
            return d
    m = int(len(x)/2)
    d1 = find_mut(x[0:m],y[0:m])
    d2 = find_mut(x[m:],y[m:])
    d3 = {**d1,**d2}
    return d3
def find_mut(x,y):    
    mutation = {}
    tmp = ''
    has_minus = False
    for i, j in enumerate(difflib.ndiff(x,y)):
        if ((j[0]=="-") and has_minus==False):
            has_minus = True
            tmp = j[0]+j[2]
        elif ((j[0]=="+")and has_minus):
            has_minus = False
            tmp+=j[0]+j[2]
            if tmp not in mutation:
                mutation[tmp]=1
            else:
                mutation[tmp]+=1
            tmp = ''
        elif ((j[0]!="+")and(j[0]!="-")and has_minus):
            has_minus = False
            if tmp not in mutation:
                mutation[tmp]=1
            else:
                mutation[tmp]+=1
            tmp = ''
        elif ((j[0]=="-")and has_minus):
            has_minus = False
            if tmp not in mutation:
                mutation[tmp]=1
            else:
                mutation[tmp]+=1
            tmp = j[0]+j[2]
    return mutation  

In [4]:
#### This function finds the positions of mutations, 
#### which helps us find the neighbouring acids around the mutation.
def find_mut_pst(x,y):    
    mutation = []
    has_minus = False
    for i, j in enumerate(difflib.ndiff(x,y)):
        if ((j[0]=="-") and has_minus==False):
            has_minus = True
        elif ((j[0]=="+")and has_minus):
            has_minus = False
            mutation.append(i)
        elif ((j[0]!="+")and(j[0]!="-")and has_minus):
            has_minus = False
            mutation.append(i)
        elif ((j[0]=="-")and has_minus):
            has_minus = False
            mutation.append(i)
    return np.array(mutation)
#### This function find the count of each amino acid in the neighbor of 
def get_nba(x,wt,nbl = 20):
    l = len(x)
    lwt = len(wt)
    freq = {}
    for i in range(l):
        left = max(x[i]-20, 0)
        right = min(lwt, x[i]+20)
        for i in range(left,right):
            ca = wt[i]
            if ca in freq:
                freq[ca]+=1/l
            elif ca not in freq:
                freq[ca]=1/l
    return freq



In [61]:
test_data = pd.read_csv('test.csv')
cdata = pd.read_csv('grouped_sample_no_pH_one_muta.csv')
seq = test_data['protein_sequence']

In [6]:
#### Generating the wild type of test samples.
test_wt = get_wt(seq)

In [7]:
#### Generating the data frame of mutation occurences of test data.

df = pd.DataFrame()
for i in seq:
    freq = find_mutation(test_wt, i)
    dff = pd.DataFrame(freq,index=[1])
    df = pd.concat([df,dff],ignore_index = True)
dm = df.fillna(0)

In [9]:
#### Generating the data frame of neighboring amino acids occurences of test data.

df2 = pd.DataFrame()
for i in seq:
    x = find_mut_pst(test_wt, i)
    freq = get_nba(x, test_wt)
    dff = pd.DataFrame(freq, index=[1])
    df2 = pd.concat([df2,dff], ignore_index = True)
dm2 = df2.fillna(0)



In [10]:
test_design=pd.concat([dm,dm2],axis = 1,ignore_index=True)

In [None]:
test_design.to_csv('test_design.csv')

In [62]:
#### Generating the design matrix of training data

group_num = max(cdata['group'])
X1 = pd.DataFrame()
X2 = pd.DataFrame()
y = np.array([])
for i in range(0,group_num+1):
    group_data = cdata[cdata['group']==i]
    g_seq = group_data['protein_sequence']
    g_seq = np.array(g_seq)
    g_y = group_data['tm']
    g_y = np.array(g_y)
    wt_index = get_wt_sup(g_seq)
    g_y = g_y - g_y[wt_index]
    y = np.append(y, g_y)
    wt = g_seq[wt_index]
    for i in g_seq:
        freq = find_mutation(wt,i)
        df = pd.DataFrame(freq,index=[1])
        X1 = pd.concat([X1,df],ignore_index = True)
        x = find_mut_pst(wt, i)
        freq2 = get_nba(x, wt)
        dff = pd.DataFrame(freq2, index=[1])
        X2 = pd.concat([X2,dff], ignore_index = True)
X1 = X1.fillna(0)
X2 = X2.fillna(0)
train_X = pd.concat([X1,X2],axis = 1, ignore_index = True)
train_y = y


In [74]:
####Ensuring design matrix(features) of training and test data have the same size.
check = pd.concat([X1, dm],ignore_index=True).fillna(0)
dim1 = len(y)
dim2 = cdata.shape[0]
X1 = check.iloc[0:dim1]
dm = check.iloc[dim1:(dim1+dim2)]
check2 = pd.concat([X2, dm2],ignore_index=True).fillna(0)
X2 = check2.iloc[0:dim1]
dm2 = check2.iloc[dim1:(dim1+dim2)]

In [89]:
####XGBoost Upgraded (Model 2)

xgb_parms = { 
    'max_depth':5, 
    'learning_rate':0.001, 
    'subsample':0.6,
    'colsample_bytree':0.3, 
    'eval_metric':'rmse',
    'objective':'reg:squarederror',
    'random_state':123
}
X_train, X_valid, y_train, y_valid = train_test_split(train_X, y, test_size=0.1, random_state=123)
model = xgb
dtrain = xgb.DMatrix(data=X_train, label=y_train)
dvalid = xgb.DMatrix(data=X_valid, label=y_valid)
dtest = xgb.DMatrix(data=test_X)

In [90]:
m2 = xgb.train(xgb_parms, dtrain=dtrain,evals=[(dtrain,"train"),(dvalid, "valid")],
                 num_boost_round=10000,
                 early_stopping_rounds = 100,
                 verbose_eval = 100)


[0]	train-rmse:12.17984	valid-rmse:12.43337
[100]	train-rmse:11.88096	valid-rmse:12.15961
[200]	train-rmse:11.61433	valid-rmse:11.90944
[300]	train-rmse:11.36296	valid-rmse:11.67640
[400]	train-rmse:11.13389	valid-rmse:11.46623
[500]	train-rmse:10.92103	valid-rmse:11.26987
[600]	train-rmse:10.72688	valid-rmse:11.09278
[700]	train-rmse:10.54345	valid-rmse:10.92550
[800]	train-rmse:10.38090	valid-rmse:10.77967
[900]	train-rmse:10.23631	valid-rmse:10.64478
[1000]	train-rmse:10.09328	valid-rmse:10.51771
[1100]	train-rmse:9.96293	valid-rmse:10.40306
[1200]	train-rmse:9.84414	valid-rmse:10.29503
[1300]	train-rmse:9.73303	valid-rmse:10.19598
[1400]	train-rmse:9.62945	valid-rmse:10.10379
[1500]	train-rmse:9.53482	valid-rmse:10.02506
[1600]	train-rmse:9.44661	valid-rmse:9.94923
[1700]	train-rmse:9.36949	valid-rmse:9.88407
[1800]	train-rmse:9.29581	valid-rmse:9.81994
[1900]	train-rmse:9.22547	valid-rmse:9.76049
[2000]	train-rmse:9.15941	valid-rmse:9.70514
[2100]	train-rmse:9.09754	valid-rmse:9.6

In [84]:
valid_pred = m2.predict(dvalid)
scipy.stats.spearmanr(valid_pred, y_valid)

SpearmanrResult(correlation=0.6486030622950832, pvalue=2.7605737795994982e-53)

In [87]:
####Grid search on the XGBoost model.

max_len = [3,4,5]
lr = [0.001,0.005,0.01]
subs = [0.5,0.6,0.7]
cols = [0.2,0.25,0.3]
rmax= 0
index = [3,0.001,0.5,0.15]
for i in max_len:
    for j in lr:
        for k in subs:
            for l in cols:
                xgb_parms = { 
                    'max_depth':i, 
                    'learning_rate':j, 
                    'subsample':k,
                    'colsample_bytree':l, 
                    'eval_metric':'rmse',
                    'objective':'reg:squarederror',
                    'random_state':123,
                    "reg_lambda":0.1
                }
                X_train, X_valid, y_train, y_valid = train_test_split(train_X, train_y, test_size=0.1, random_state=123)
                model = xgb
                dtrain = xgb.DMatrix(data=X_train, label=y_train)
                dvalid = xgb.DMatrix(data=X_valid, label=y_valid)

                m2 = xgb.train(xgb_parms, dtrain=dtrain,evals=[(dtrain,"train"),(dvalid, "valid")],
                                 num_boost_round=5000,
                                 early_stopping_rounds = 100,
                                 verbose_eval = 100)

                pred=m2.predict(dvalid)
                result = scipy.stats.spearmanr(pred, y_valid)[0]
                print(i,j,k,l)
                print(result)
                if result > rmax:
                    index[0]=i
                    index[1]=j
                    index[2]=k
                    index[3]=l
                    rmax = result


[0]	train-rmse:12.18183	valid-rmse:12.43504
[100]	train-rmse:12.01955	valid-rmse:12.27673
[200]	train-rmse:11.85986	valid-rmse:12.12117
[300]	train-rmse:11.71437	valid-rmse:11.97871
[400]	train-rmse:11.58127	valid-rmse:11.84785
[500]	train-rmse:11.44999	valid-rmse:11.71885
[600]	train-rmse:11.33185	valid-rmse:11.60310
[700]	train-rmse:11.21754	valid-rmse:11.49582
[800]	train-rmse:11.10816	valid-rmse:11.38886
[900]	train-rmse:11.00998	valid-rmse:11.29090
[1000]	train-rmse:10.91109	valid-rmse:11.19879
[1100]	train-rmse:10.81930	valid-rmse:11.10887
[1200]	train-rmse:10.73679	valid-rmse:11.02982
[1300]	train-rmse:10.65405	valid-rmse:10.94772
[1400]	train-rmse:10.57731	valid-rmse:10.87148
[1500]	train-rmse:10.50348	valid-rmse:10.80059
[1600]	train-rmse:10.43289	valid-rmse:10.73127
[1700]	train-rmse:10.36520	valid-rmse:10.66764
[1800]	train-rmse:10.30223	valid-rmse:10.60637
[1900]	train-rmse:10.24343	valid-rmse:10.55022
[2000]	train-rmse:10.18501	valid-rmse:10.49395
[2100]	train-rmse:10.1323

[2400]	train-rmse:9.97328	valid-rmse:10.31080
[2500]	train-rmse:9.92779	valid-rmse:10.26395
[2600]	train-rmse:9.88481	valid-rmse:10.22375
[2700]	train-rmse:9.84290	valid-rmse:10.18502
[2800]	train-rmse:9.80394	valid-rmse:10.14804
[2900]	train-rmse:9.76425	valid-rmse:10.10834
[3000]	train-rmse:9.72807	valid-rmse:10.07474
[3100]	train-rmse:9.69296	valid-rmse:10.04290
[3200]	train-rmse:9.65829	valid-rmse:10.00892
[3300]	train-rmse:9.62651	valid-rmse:9.97815
[3400]	train-rmse:9.59582	valid-rmse:9.94857
[3500]	train-rmse:9.56506	valid-rmse:9.91878
[3600]	train-rmse:9.53566	valid-rmse:9.89251
[3700]	train-rmse:9.50659	valid-rmse:9.86658
[3800]	train-rmse:9.47748	valid-rmse:9.84070
[3900]	train-rmse:9.45092	valid-rmse:9.81690
[4000]	train-rmse:9.42436	valid-rmse:9.79535
[4100]	train-rmse:9.39713	valid-rmse:9.77254
[4200]	train-rmse:9.37184	valid-rmse:9.75077
[4300]	train-rmse:9.34691	valid-rmse:9.73058
[4400]	train-rmse:9.32121	valid-rmse:9.70867
[4500]	train-rmse:9.29762	valid-rmse:9.68723
[

[4800]	train-rmse:9.23437	valid-rmse:9.62726
[4900]	train-rmse:9.21342	valid-rmse:9.60870
[4999]	train-rmse:9.19362	valid-rmse:9.59278
3 0.001 0.7 0.2
0.6111319726675388
[0]	train-rmse:12.18136	valid-rmse:12.43401
[100]	train-rmse:12.00417	valid-rmse:12.26060
[200]	train-rmse:11.83454	valid-rmse:12.09593
[300]	train-rmse:11.67735	valid-rmse:11.94007
[400]	train-rmse:11.53562	valid-rmse:11.79978
[500]	train-rmse:11.39626	valid-rmse:11.66549
[600]	train-rmse:11.26996	valid-rmse:11.54123
[700]	train-rmse:11.14780	valid-rmse:11.42433
[800]	train-rmse:11.03595	valid-rmse:11.31455
[900]	train-rmse:10.93440	valid-rmse:11.21487
[1000]	train-rmse:10.83313	valid-rmse:11.11940
[1100]	train-rmse:10.74046	valid-rmse:11.03182
[1200]	train-rmse:10.65302	valid-rmse:10.94861
[1300]	train-rmse:10.57065	valid-rmse:10.86923
[1400]	train-rmse:10.49033	valid-rmse:10.79216
[1500]	train-rmse:10.41782	valid-rmse:10.72349
[1600]	train-rmse:10.34890	valid-rmse:10.65704
[1700]	train-rmse:10.28581	valid-rmse:10.59

[2200]	train-rmse:8.35860	valid-rmse:8.94145
[2300]	train-rmse:8.31671	valid-rmse:8.91836
[2400]	train-rmse:8.27984	valid-rmse:8.89569
[2500]	train-rmse:8.24216	valid-rmse:8.87413
[2600]	train-rmse:8.20617	valid-rmse:8.85661
[2700]	train-rmse:8.16960	valid-rmse:8.84247
[2800]	train-rmse:8.13467	valid-rmse:8.83026
[2900]	train-rmse:8.10156	valid-rmse:8.82075
[3000]	train-rmse:8.07064	valid-rmse:8.80950
[3100]	train-rmse:8.04120	valid-rmse:8.79363
[3200]	train-rmse:8.01502	valid-rmse:8.78314
[3300]	train-rmse:7.98687	valid-rmse:8.77281
[3400]	train-rmse:7.96185	valid-rmse:8.76180
[3500]	train-rmse:7.93798	valid-rmse:8.75556
[3600]	train-rmse:7.91383	valid-rmse:8.75461
[3671]	train-rmse:7.89694	valid-rmse:8.75343
3 0.005 0.5 0.25
0.6484933448364583
[0]	train-rmse:12.17067	valid-rmse:12.42250
[100]	train-rmse:11.36336	valid-rmse:11.65207
[200]	train-rmse:10.79539	valid-rmse:11.09415
[300]	train-rmse:10.38540	valid-rmse:10.68749
[400]	train-rmse:10.07424	valid-rmse:10.37949
[500]	train-rmse

[3000]	train-rmse:8.04790	valid-rmse:8.82464
[3100]	train-rmse:8.01850	valid-rmse:8.81042
[3200]	train-rmse:7.99001	valid-rmse:8.79629
[3300]	train-rmse:7.96474	valid-rmse:8.79379
[3400]	train-rmse:7.93728	valid-rmse:8.78306
[3500]	train-rmse:7.91258	valid-rmse:8.77673
[3600]	train-rmse:7.88885	valid-rmse:8.76971
[3700]	train-rmse:7.86443	valid-rmse:8.76973
[3800]	train-rmse:7.84141	valid-rmse:8.76900
[3900]	train-rmse:7.81934	valid-rmse:8.76143
[4000]	train-rmse:7.79565	valid-rmse:8.75647
[4100]	train-rmse:7.77428	valid-rmse:8.75083
[4200]	train-rmse:7.75317	valid-rmse:8.74758
[4300]	train-rmse:7.73325	valid-rmse:8.74617
[4400]	train-rmse:7.71291	valid-rmse:8.74639
[4500]	train-rmse:7.69298	valid-rmse:8.74199
[4600]	train-rmse:7.67567	valid-rmse:8.73616
[4700]	train-rmse:7.65697	valid-rmse:8.73342
[4800]	train-rmse:7.63766	valid-rmse:8.72931
[4900]	train-rmse:7.62179	valid-rmse:8.72737
[4999]	train-rmse:7.60456	valid-rmse:8.73030
3 0.005 0.6 0.3
0.6528095641379686
[0]	train-rmse:12.17

3 0.01 0.5 0.2
0.6451125425804145
[0]	train-rmse:12.15781	valid-rmse:12.40866
[100]	train-rmse:10.84071	valid-rmse:11.13573
[200]	train-rmse:10.13692	valid-rmse:10.45080
[300]	train-rmse:9.69824	valid-rmse:10.02435
[400]	train-rmse:9.39798	valid-rmse:9.75397
[500]	train-rmse:9.16558	valid-rmse:9.54823
[600]	train-rmse:8.96945	valid-rmse:9.39952
[700]	train-rmse:8.81149	valid-rmse:9.27120
[800]	train-rmse:8.67661	valid-rmse:9.18244
[900]	train-rmse:8.55779	valid-rmse:9.08766
[1000]	train-rmse:8.44867	valid-rmse:9.03738
[1100]	train-rmse:8.35745	valid-rmse:8.97456
[1200]	train-rmse:8.27931	valid-rmse:8.92527
[1300]	train-rmse:8.20423	valid-rmse:8.89533
[1400]	train-rmse:8.13711	valid-rmse:8.85120
[1500]	train-rmse:8.07075	valid-rmse:8.82356
[1600]	train-rmse:8.01396	valid-rmse:8.80536
[1700]	train-rmse:7.96403	valid-rmse:8.80540
[1800]	train-rmse:7.91327	valid-rmse:8.78309
[1900]	train-rmse:7.86828	valid-rmse:8.78625
[1922]	train-rmse:7.85698	valid-rmse:8.79121
3 0.01 0.5 0.25
0.64751632

[100]	train-rmse:10.76661	valid-rmse:11.05953
[200]	train-rmse:10.04796	valid-rmse:10.36693
[300]	train-rmse:9.63764	valid-rmse:9.98231
[400]	train-rmse:9.33128	valid-rmse:9.71446
[500]	train-rmse:9.10978	valid-rmse:9.52347
[600]	train-rmse:8.92732	valid-rmse:9.39275
[700]	train-rmse:8.77728	valid-rmse:9.27993
[800]	train-rmse:8.65096	valid-rmse:9.21314
[900]	train-rmse:8.53575	valid-rmse:9.13772
[1000]	train-rmse:8.42630	valid-rmse:9.07788
[1100]	train-rmse:8.33996	valid-rmse:9.02692
[1200]	train-rmse:8.25657	valid-rmse:8.97587
[1300]	train-rmse:8.18573	valid-rmse:8.93846
[1400]	train-rmse:8.11474	valid-rmse:8.89548
[1500]	train-rmse:8.05288	valid-rmse:8.86563
[1600]	train-rmse:7.99355	valid-rmse:8.84336
[1700]	train-rmse:7.94122	valid-rmse:8.83047
[1800]	train-rmse:7.89044	valid-rmse:8.80135
[1900]	train-rmse:7.84302	valid-rmse:8.79849
[2000]	train-rmse:7.79260	valid-rmse:8.78699
[2100]	train-rmse:7.74734	valid-rmse:8.77876
[2200]	train-rmse:7.71086	valid-rmse:8.77817
[2300]	train-rm

[4700]	train-rmse:8.60023	valid-rmse:9.19599
[4800]	train-rmse:8.57779	valid-rmse:9.18015
[4900]	train-rmse:8.55694	valid-rmse:9.16545
[4999]	train-rmse:8.53761	valid-rmse:9.15481
4 0.001 0.5 0.3
0.6333648171845128
[0]	train-rmse:12.18114	valid-rmse:12.43457
[100]	train-rmse:11.96638	valid-rmse:12.23198
[200]	train-rmse:11.76134	valid-rmse:12.04181
[300]	train-rmse:11.57374	valid-rmse:11.86295
[400]	train-rmse:11.40149	valid-rmse:11.69905
[500]	train-rmse:11.23565	valid-rmse:11.54567
[600]	train-rmse:11.08563	valid-rmse:11.40701
[700]	train-rmse:10.94088	valid-rmse:11.27212
[800]	train-rmse:10.80624	valid-rmse:11.14675
[900]	train-rmse:10.68373	valid-rmse:11.02947
[1000]	train-rmse:10.56276	valid-rmse:10.92277
[1100]	train-rmse:10.45052	valid-rmse:10.81941
[1200]	train-rmse:10.35075	valid-rmse:10.72861
[1300]	train-rmse:10.25312	valid-rmse:10.63850
[1400]	train-rmse:10.15862	valid-rmse:10.55147
[1500]	train-rmse:10.07153	valid-rmse:10.47473
[1600]	train-rmse:9.98695	valid-rmse:10.39656

[2000]	train-rmse:9.69392	valid-rmse:10.14363
[2100]	train-rmse:9.63371	valid-rmse:10.08956
[2200]	train-rmse:9.57649	valid-rmse:10.03893
[2300]	train-rmse:9.52237	valid-rmse:9.99109
[2400]	train-rmse:9.47025	valid-rmse:9.94581
[2500]	train-rmse:9.41887	valid-rmse:9.89909
[2600]	train-rmse:9.37044	valid-rmse:9.85959
[2700]	train-rmse:9.32351	valid-rmse:9.81931
[2800]	train-rmse:9.28033	valid-rmse:9.78348
[2900]	train-rmse:9.23769	valid-rmse:9.74427
[3000]	train-rmse:9.19649	valid-rmse:9.70979
[3100]	train-rmse:9.15908	valid-rmse:9.67802
[3200]	train-rmse:9.12279	valid-rmse:9.64740
[3300]	train-rmse:9.08905	valid-rmse:9.61816
[3400]	train-rmse:9.05445	valid-rmse:9.59089
[3500]	train-rmse:9.02120	valid-rmse:9.56397
[3600]	train-rmse:8.99023	valid-rmse:9.53983
[3700]	train-rmse:8.95965	valid-rmse:9.51634
[3800]	train-rmse:8.92875	valid-rmse:9.49127
[3900]	train-rmse:8.90027	valid-rmse:9.46825
[4000]	train-rmse:8.87264	valid-rmse:9.44978
[4100]	train-rmse:8.84407	valid-rmse:9.42735
[4200]	

[1400]	train-rmse:8.24862	valid-rmse:8.94479
[1500]	train-rmse:8.18067	valid-rmse:8.91148
[1600]	train-rmse:8.11802	valid-rmse:8.87438
[1700]	train-rmse:8.06443	valid-rmse:8.85496
[1800]	train-rmse:8.01197	valid-rmse:8.82435
[1900]	train-rmse:7.96072	valid-rmse:8.80748
[2000]	train-rmse:7.91065	valid-rmse:8.78981
[2100]	train-rmse:7.86345	valid-rmse:8.77544
[2200]	train-rmse:7.82752	valid-rmse:8.76001
[2300]	train-rmse:7.78835	valid-rmse:8.74361
[2400]	train-rmse:7.75239	valid-rmse:8.73041
[2500]	train-rmse:7.71434	valid-rmse:8.71431
[2600]	train-rmse:7.67998	valid-rmse:8.70787
[2700]	train-rmse:7.64377	valid-rmse:8.70414
[2800]	train-rmse:7.61113	valid-rmse:8.70378
[2900]	train-rmse:7.57787	valid-rmse:8.69920
[3000]	train-rmse:7.54848	valid-rmse:8.69760
[3100]	train-rmse:7.52022	valid-rmse:8.69090
[3200]	train-rmse:7.49527	valid-rmse:8.69072
[3213]	train-rmse:7.49236	valid-rmse:8.69099
4 0.005 0.5 0.25
0.6583339356572332
[0]	train-rmse:12.16632	valid-rmse:12.42083
[100]	train-rmse:11.

[2800]	train-rmse:7.63795	valid-rmse:8.79653
[2900]	train-rmse:7.60196	valid-rmse:8.78466
[3000]	train-rmse:7.57264	valid-rmse:8.77567
[3100]	train-rmse:7.54522	valid-rmse:8.76852
[3200]	train-rmse:7.51793	valid-rmse:8.76406
[3300]	train-rmse:7.49186	valid-rmse:8.75696
[3400]	train-rmse:7.46444	valid-rmse:8.75029
[3500]	train-rmse:7.43879	valid-rmse:8.74730
[3600]	train-rmse:7.41469	valid-rmse:8.75146
[3639]	train-rmse:7.40571	valid-rmse:8.75342
4 0.005 0.7 0.2
0.6568427543904024
[0]	train-rmse:12.16558	valid-rmse:12.42093
[100]	train-rmse:11.16684	valid-rmse:11.47854
[200]	train-rmse:10.47171	valid-rmse:10.84049
[300]	train-rmse:9.97352	valid-rmse:10.38204
[400]	train-rmse:9.61350	valid-rmse:10.06049
[500]	train-rmse:9.33291	valid-rmse:9.80767
[600]	train-rmse:9.12624	valid-rmse:9.63195
[700]	train-rmse:8.95276	valid-rmse:9.48882
[800]	train-rmse:8.81166	valid-rmse:9.38904
[900]	train-rmse:8.68224	valid-rmse:9.28872
[1000]	train-rmse:8.57582	valid-rmse:9.21932
[1100]	train-rmse:8.4736

[0]	train-rmse:12.15864	valid-rmse:12.41189
[100]	train-rmse:10.57426	valid-rmse:10.93989
[200]	train-rmse:9.72362	valid-rmse:10.17584
[300]	train-rmse:9.21113	valid-rmse:9.73112
[400]	train-rmse:8.88370	valid-rmse:9.45629
[500]	train-rmse:8.64476	valid-rmse:9.27345
[600]	train-rmse:8.45243	valid-rmse:9.15472
[700]	train-rmse:8.29493	valid-rmse:9.06241
[800]	train-rmse:8.16509	valid-rmse:8.99602
[900]	train-rmse:8.05226	valid-rmse:8.94393
[1000]	train-rmse:7.95489	valid-rmse:8.90733
[1100]	train-rmse:7.86348	valid-rmse:8.86596
[1200]	train-rmse:7.78433	valid-rmse:8.82585
[1300]	train-rmse:7.70929	valid-rmse:8.80950
[1400]	train-rmse:7.64052	valid-rmse:8.78529
[1500]	train-rmse:7.58019	valid-rmse:8.77336
[1600]	train-rmse:7.51459	valid-rmse:8.76204
[1700]	train-rmse:7.46096	valid-rmse:8.75135
[1800]	train-rmse:7.41073	valid-rmse:8.74003
[1900]	train-rmse:7.36645	valid-rmse:8.74175
[2000]	train-rmse:7.31756	valid-rmse:8.73275
[2081]	train-rmse:7.28146	valid-rmse:8.73496
4 0.01 0.7 0.2
0.

[1200]	train-rmse:9.81996	valid-rmse:10.28762
[1300]	train-rmse:9.70661	valid-rmse:10.18935
[1400]	train-rmse:9.60161	valid-rmse:10.09579
[1500]	train-rmse:9.50445	valid-rmse:10.01569
[1600]	train-rmse:9.41566	valid-rmse:9.94097
[1700]	train-rmse:9.33436	valid-rmse:9.87352
[1800]	train-rmse:9.25747	valid-rmse:9.80827
[1900]	train-rmse:9.18630	valid-rmse:9.74864
[2000]	train-rmse:9.11930	valid-rmse:9.69168
[2100]	train-rmse:9.05690	valid-rmse:9.64265
[2200]	train-rmse:9.00000	valid-rmse:9.59500
[2300]	train-rmse:8.94319	valid-rmse:9.54824
[2400]	train-rmse:8.88811	valid-rmse:9.50421
[2500]	train-rmse:8.83424	valid-rmse:9.46035
[2600]	train-rmse:8.78641	valid-rmse:9.42140
[2700]	train-rmse:8.73966	valid-rmse:9.38646
[2800]	train-rmse:8.69687	valid-rmse:9.35328
[2900]	train-rmse:8.65393	valid-rmse:9.32164
[3000]	train-rmse:8.61491	valid-rmse:9.29453
[3100]	train-rmse:8.57747	valid-rmse:9.26759
[3200]	train-rmse:8.54286	valid-rmse:9.24235
[3300]	train-rmse:8.50838	valid-rmse:9.21898
[3400]

[3800]	train-rmse:8.33915	valid-rmse:9.13121
[3900]	train-rmse:8.31312	valid-rmse:9.11462
[4000]	train-rmse:8.28619	valid-rmse:9.09647
[4100]	train-rmse:8.26043	valid-rmse:9.08076
[4200]	train-rmse:8.23458	valid-rmse:9.06393
[4300]	train-rmse:8.21008	valid-rmse:9.04900
[4400]	train-rmse:8.18567	valid-rmse:9.03533
[4500]	train-rmse:8.16369	valid-rmse:9.02267
[4600]	train-rmse:8.14239	valid-rmse:9.01010
[4700]	train-rmse:8.12167	valid-rmse:8.99906
[4800]	train-rmse:8.10110	valid-rmse:8.98770
[4900]	train-rmse:8.08082	valid-rmse:8.97653
[4999]	train-rmse:8.06196	valid-rmse:8.96664
5 0.001 0.6 0.3
0.6471795595220396
[0]	train-rmse:12.18062	valid-rmse:12.43416
[100]	train-rmse:11.91937	valid-rmse:12.19278
[200]	train-rmse:11.66317	valid-rmse:11.95832
[300]	train-rmse:11.43400	valid-rmse:11.74533
[400]	train-rmse:11.22040	valid-rmse:11.55102
[500]	train-rmse:11.01869	valid-rmse:11.37182
[600]	train-rmse:10.83958	valid-rmse:11.20685
[700]	train-rmse:10.66556	valid-rmse:11.05079
[800]	train-rm

[1200]	train-rmse:8.02184	valid-rmse:8.89993
[1300]	train-rmse:7.94904	valid-rmse:8.87145
[1400]	train-rmse:7.87585	valid-rmse:8.83120
[1500]	train-rmse:7.81171	valid-rmse:8.80986
[1600]	train-rmse:7.75321	valid-rmse:8.79170
[1700]	train-rmse:7.70183	valid-rmse:8.77625
[1800]	train-rmse:7.64882	valid-rmse:8.75817
[1900]	train-rmse:7.59982	valid-rmse:8.74985
[2000]	train-rmse:7.55203	valid-rmse:8.74186
[2100]	train-rmse:7.51051	valid-rmse:8.73194
[2200]	train-rmse:7.47368	valid-rmse:8.72013
[2300]	train-rmse:7.43536	valid-rmse:8.70963
[2400]	train-rmse:7.39771	valid-rmse:8.69793
[2500]	train-rmse:7.35901	valid-rmse:8.68303
[2600]	train-rmse:7.32575	valid-rmse:8.68057
[2669]	train-rmse:7.30267	valid-rmse:8.68254
5 0.005 0.5 0.2
0.6629299923422852
[0]	train-rmse:12.16300	valid-rmse:12.42026
[100]	train-rmse:10.96468	valid-rmse:11.32049
[200]	train-rmse:10.14653	valid-rmse:10.58227
[300]	train-rmse:9.58568	valid-rmse:10.07651
[400]	train-rmse:9.19892	valid-rmse:9.73717
[500]	train-rmse:8.9

[400]	train-rmse:9.15914	valid-rmse:9.74092
[500]	train-rmse:8.86124	valid-rmse:9.50957
[600]	train-rmse:8.64233	valid-rmse:9.34663
[700]	train-rmse:8.46542	valid-rmse:9.21920
[800]	train-rmse:8.32372	valid-rmse:9.13550
[900]	train-rmse:8.19957	valid-rmse:9.05976
[1000]	train-rmse:8.09338	valid-rmse:9.01357
[1100]	train-rmse:7.99838	valid-rmse:8.95969
[1200]	train-rmse:7.91578	valid-rmse:8.91614
[1300]	train-rmse:7.84476	valid-rmse:8.89161
[1400]	train-rmse:7.77445	valid-rmse:8.85942
[1500]	train-rmse:7.71027	valid-rmse:8.84211
[1600]	train-rmse:7.64866	valid-rmse:8.82374
[1700]	train-rmse:7.59384	valid-rmse:8.81269
[1800]	train-rmse:7.54281	valid-rmse:8.79458
[1900]	train-rmse:7.49880	valid-rmse:8.79341
[2000]	train-rmse:7.45496	valid-rmse:8.78723
[2100]	train-rmse:7.41140	valid-rmse:8.78221
[2200]	train-rmse:7.37615	valid-rmse:8.78191
[2266]	train-rmse:7.35197	valid-rmse:8.78282
5 0.005 0.7 0.25
0.6574966881897868
[0]	train-rmse:12.16263	valid-rmse:12.41816
[100]	train-rmse:10.85642	

[1000]	train-rmse:7.43118	valid-rmse:8.87334
[1100]	train-rmse:7.35169	valid-rmse:8.86181
[1200]	train-rmse:7.27506	valid-rmse:8.84708
[1300]	train-rmse:7.20740	valid-rmse:8.84694
[1337]	train-rmse:7.18142	valid-rmse:8.84370
5 0.01 0.7 0.3
0.6591797527022162


In [88]:
index

[5, 0.01, 0.6, 0.3]

In [92]:
pred=m2.predict(dtest)
seq_id = []
tm = []
for i in range(len(pred)):
    seq_id.append(31390+i)
    tm.append(pred[i])
data = {
    "seq_id":seq_id,
    "tm":tm
}
df = pd.DataFrame(data)
df.to_csv("submission.csv")

In [None]:
#### Generating the design matrix of training data and validation data
#### The validation data are classified based on the group number.
group_num = max(cdata['group'])
valid_size = 0
valid_index=np.array(random.sample(range(0, group_num+1), int(group_num*0.3)))
train_index = np.array(range(0,group_num+1))
train_index = np.delete(train_index,valid_index)
X1_train = pd.DataFrame()
X2_train = pd.DataFrame()
X1_valid = pd.DataFrame()
X2_valid = pd.DataFrame()
train_y = np.array([])
valid_y = np.array([])
y = np.array([])
for i in range(0,group_num+1):
    group_data = cdata[cdata['group']==i]
    g_seq = group_data['protein_sequence']
    g_seq = np.array(g_seq)
    g_y = group_data['tm']
    g_y = np.array(g_y)
    wt_index = get_wt_sup(g_seq)
    g_y = g_y - g_y[wt_index]
    y = np.append(y, g_y)
    wt = g_seq[wt_index]
    if (i in train_index):
        for k in g_seq:
            freq = find_mutation(wt,k)
            df = pd.DataFrame(freq,index=[1])
            X1_train = pd.concat([X1_train,df],ignore_index = True)
            x = find_mut_pst(wt, k)
            freq2 = get_nba(x, wt)
            dff = pd.DataFrame(freq2, index=[1])
            X2_train = pd.concat([X2_train,dff], ignore_index = True)
        train_y = np.append(train_y, g_y)
    else:
        for k in g_seq:
            freq = find_mutation(wt,k)
            df = pd.DataFrame(freq,index=[1])
            X1_valid = pd.concat([X1_valid,df],ignore_index = True)
            x = find_mut_pst(wt, k)
            freq2 = get_nba(x, wt)
            dff = pd.DataFrame(freq2, index=[1])
            X2_valid = pd.concat([X2_valid,dff], ignore_index = True)
        valid_y = np.append(valid_y, g_y)
X1_train = X1_train.fillna(0)
X2_train = X2_train.fillna(0)
X1_valid = X1_valid.fillna(0)
X2_valid = X2_valid.fillna(0)
train_X = pd.concat([X1_train,X2_train],axis = 1)
valid_X = pd.concat([X1_valid,X2_valid],axis = 1)
train_y = pd.DataFrame(train_y)
valid_y = pd.DataFrame(valid_y)

In [None]:
####Ensuring design matrix(features) of training and test data have the same size.
check_1 = pd.concat([train_X, valid_X],ignore_index = True).fillna(0)
check_2 = pd.concat([check_1, test_design], ignore_index = True).fillna(0)
dim1 = train_X.shape[0]
dim2 = valid_X.shape[0]
dim3 = check_2.shape[0]
train_X = check_2.iloc[0:dim1]
valid_X = check_2.iloc[dim1:dim1+dim2]
test_design = check_2.iloc[dim1+dim2:dim3]
train_X.to_csv('train_X.csv')
train_y.to_csv('train_y.csv')

In [None]:
# Model (2) This model has a validation set that only contains enzymes 
# with completely different wildtypes from wildtypes of the training set
xgb_parms = { 
    'max_depth':4, 
    'learning_rate':0.001, 
    'subsample':0.5,
    'colsample_bytree':0.2, 
    'eval_metric':'rmse',
    'objective':'reg:squarederror',
    'random_state':123,
    "reg_lambda":0.1
}

model = xgb
dtrain = xgb.DMatrix(data=train_X, label=train_y)
dvalid = xgb.DMatrix(data=valid_X, label=valid_y)
dtest = xgb.DMatrix(data=test_design)
m3 = xgb.train(xgb_parms, dtrain=dtrain,evals=[(dtrain,"train"),(dvalid, "valid")],
                 num_boost_round=10000,
                 early_stopping_rounds = 100,
                 verbose_eval = 100)

In [None]:
####Neural network (A neural network Model)
####This model did not perform well.

import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
tX1 = torch.tensor(train_X.values)
tdm = torch.tensor(test_X.values)
ty = torch.tensor(y)
X_train, X_test, y_train, y_test = train_test_split(tX1, ty, test_size=0.1, random_state=123)


n_input, n_hidden, n_out, batch_size, learning_rate = 406, 83, 1, 100, 0.1
model = nn.Sequential(nn.Linear(n_input, n_hidden),
                      nn.ReLU(),
                      nn.Linear(n_hidden, 64),
                      nn.ReLU(),
                      nn.Linear(64, 16),
                      nn.ReLU(),
                      nn.Linear(16,4),
                      nn.ReLU(),
                      nn.Linear(4, n_out),
                      nn.Sigmoid())

print(model)
model.double()
loss_function = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
losses = []
mlos = []
for epoch in range(1000):
    print(epoch)
    pred_y = model(X_train)
    pred_yy = model(X_test).float()
    loss = loss_function(pred_y, y_train)
    mloss = loss_function(pred_yy, y_test)
    losses.append(loss.item())
    mlos.append(mloss.item())
    model.zero_grad()
    loss.backward()

    optimizer.step()