In [1]:
import pandas as pd
import numpy as np
from copy import deepcopy
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.metrics import roc_curve, precision_recall_curve, average_precision_score, f1_score
from sklearn.model_selection import GroupKFold, GroupShuffleSplit
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import ExtraTreesClassifier
import lightgbm as lgb
import xgboost as xgb
from catboost import Pool, CatBoostClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import StratifiedKFold
from sklearn.naive_bayes import GaussianNB 
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression

In [2]:
from IPython.core.display import clear_output
import sys
def change_output(x):
    clear_output()
    sys.stdout.write(x)
    sys.stdout.flush()


In [3]:
def load_data(filename, sep=","):
    '''
        Функция для загрузки датасета без инициализации класса. 
        
        Возвращает:
         data - полный DataFrame
         fetures - список признаков для обучения/предсказания
         x - датасет для тренировки/предсказания в формате np.array
         y - целевой признак, если он есть. В случае с тестовым датасетом возвращает None
    '''
    
    data = pd.read_table(filename, sep=sep).dropna() #.fillna(-999999)    
    
    data = data[~(data['chainlen']>1000)]

    if ('DSSR' in data.columns):
        data.drop('DSSR', axis=1, inplace=True)    
    
    features = list(deepcopy(data.columns))
    [features.remove(column) for column in ['Id','index', 'pdb_chain', 'mg'] if column in data.columns];
    x_test = np.array(data[features])
    
    try:
        y_test = np.array(data['mg'])
    except: 
        y_test = None
    change_output('Data loaded')
    return {'data':data, 'features':features, 'x':x_test, 'y':y_test}

In [4]:
test = load_data("test.csv")

Data loaded

In [5]:
change_output('Loading data...')        
data = pd.read_table("train.csv")
change_output('Data processing...')  
if ('DSSR' in data.columns):
    data.drop('DSSR', axis=1, inplace=True)            
data = data.dropna()
y = deepcopy(np.array(np.matrix(data['mg']).flatten().tolist()[0])) 
data_numpy = np.matrix(data)
features = list(data.columns)
features.remove('pdb_chain')
features.remove('mg')
groups = np.array(data['pdb_chain'])
X = np.array(data[features])
features.append('mg')
change_output('Done')  

Done

In [6]:
fi = np.array([8.37719494e-04, 5.95851621e-03, 9.85790946e-03, 6.58311779e-04,
       4.98447384e-03, 5.34551361e-03, 4.41471291e-03, 4.50505970e-03,
       4.56330928e-03, 4.70680305e-03, 4.30524333e-03, 4.91976436e-03,
       3.97949899e-03, 4.37511247e-03, 4.07568193e-03, 5.17924926e-03,
       4.92749131e-03, 4.81285050e-03, 4.85455992e-03, 4.80210650e-03,
       5.44854764e-03, 4.88728381e-03, 4.26637682e-03, 4.53950962e-03,
       4.67869705e-03, 4.68349264e-03, 4.30915245e-03, 4.56332910e-03,
       4.13989792e-03, 4.64412344e-03, 5.55468678e-03, 4.52450085e-03,
       4.29287607e-03, 4.28285976e-03, 4.46584093e-03, 4.32664699e-03,
       4.51653001e-03, 4.21761114e-03, 4.27743202e-03, 4.64259729e-03,
       5.31728809e-03, 4.87377702e-03, 4.75601150e-03, 5.15440112e-03,
       5.53981827e-03, 5.23599622e-03, 5.26360711e-03, 3.91604110e-03,
       4.06798616e-03, 4.08266886e-03, 4.37467342e-03, 4.21931663e-03,
       4.46754553e-03, 4.24214126e-03, 4.55833046e-03, 5.33798445e-03,
       4.54425022e-03, 4.44415310e-03, 4.30843803e-03, 4.91772122e-03,
       4.17315336e-03, 4.95835987e-03, 3.80753214e-03, 4.39400751e-03,
       4.36322021e-03, 5.53643574e-03, 5.03440048e-03, 5.09679402e-03,
       5.20133254e-03, 5.62837785e-03, 5.36533506e-03, 5.14070197e-03,
       4.04494055e-03, 4.16946674e-03, 4.26667193e-03, 4.50308895e-03,
       3.96861546e-03, 4.59122884e-03, 4.20598830e-03, 4.69359791e-03,
       5.19369139e-03, 4.63584549e-03, 4.15966874e-03, 4.31123342e-03,
       4.58478757e-03, 4.35197617e-03, 4.79957404e-03, 4.19420419e-03,
       4.47094371e-03, 4.19985206e-03, 5.11594398e-03, 4.80625449e-03,
       4.79665905e-03, 4.88522230e-03, 5.61647547e-03, 5.47869701e-03,
       5.08340733e-03, 4.01876782e-03, 4.24213902e-03, 4.19247561e-03,
       4.34883928e-03, 4.39333467e-03, 4.37419072e-03, 4.04229371e-03,
       4.60332310e-03, 5.41950686e-03, 4.68108845e-03, 4.24900379e-03,
       4.34313570e-03, 4.47052198e-03, 4.55504200e-03, 4.67954354e-03,
       4.27798809e-03, 4.29317089e-03, 4.19199390e-03, 5.11875450e-03,
       4.90476795e-03, 4.86520413e-03, 5.01986410e-03, 5.43579538e-03,
       4.92259985e-03, 4.90694162e-03, 4.06197782e-03, 4.00283450e-03,
       4.25164962e-03, 4.42927435e-03, 4.16515525e-03, 4.46049020e-03,
       4.23435969e-03, 5.79104448e-04, 5.27844534e-04, 5.49401462e-04,
       5.08870959e-04, 5.43228006e-04, 5.26209835e-04, 6.33428325e-04,
       5.52223272e-04, 5.70007605e-04, 5.71919070e-04, 7.97155912e-04,
       4.84353490e-04, 6.61278601e-04, 7.31656713e-04, 9.63742407e-04,
       5.07483643e-04, 5.95337295e-04, 5.04932229e-04, 5.64716331e-04,
       4.53721215e-04, 5.69724214e-05, 9.99815893e-05, 1.83130567e-06,
       2.56696431e-05, 1.29364181e-04, 5.35290363e-05, 1.05337155e-04,
       2.07823015e-04, 3.66725195e-05, 1.16943202e-04, 3.79581850e-04,
       1.79645464e-05, 7.33507683e-07, 0.00000000e+00, 8.67001839e-06,
       1.71346338e-04, 0.00000000e+00, 9.68122952e-05, 6.73772053e-04,
       5.34915874e-04, 1.41833425e-04, 9.91809531e-05, 1.56198900e-04,
       4.41293854e-04, 1.09518792e-04, 6.45897969e-06, 3.48770882e-06,
       2.96577370e-04, 6.21098595e-05, 2.97335704e-04, 1.87282539e-04,
       5.87601197e-04, 2.10357337e-04, 1.26120027e-04, 3.66957905e-04,
       3.30279986e-04, 2.52522780e-04, 1.39314570e-04, 2.56529143e-04,
       1.49588667e-04, 4.00874365e-04, 6.13512153e-05, 1.62177644e-04,
       3.93548478e-06, 2.88900451e-05, 8.78022806e-05, 9.99540421e-05,
       1.48591171e-04, 1.80715757e-04, 6.26997490e-05, 1.79282217e-04,
       3.32733309e-04, 1.29912914e-05, 8.39516989e-06, 0.00000000e+00,
       1.05977155e-05, 1.93168155e-04, 0.00000000e+00, 7.91991931e-05,
       5.28009223e-04, 4.60317066e-04, 1.17261591e-04, 7.62299952e-05,
       2.98205775e-04, 4.09607879e-04, 4.94006310e-05, 2.41866080e-05,
       2.80139315e-06, 3.36784591e-04, 1.39955896e-04, 3.33939460e-04,
       2.18863761e-04, 5.88060757e-04, 2.90341727e-04, 2.14630452e-04,
       3.01876443e-04, 4.83256075e-04, 2.69988276e-04, 1.63472408e-04,
       2.59334517e-04, 1.63325507e-04, 3.54313832e-04, 4.89715332e-05,
       9.70264976e-05, 0.00000000e+00, 7.71810016e-06, 9.54130181e-05,
       9.89999999e-05, 7.92013869e-05, 1.60029549e-04, 1.04166914e-04,
       2.06765875e-04, 3.43121145e-04, 6.88432918e-06, 2.64602395e-05,
       0.00000000e+00, 4.83238088e-06, 1.41323649e-04, 0.00000000e+00,
       9.30730663e-05, 6.04684753e-04, 4.31345815e-04, 8.51820623e-05,
       6.26713544e-05, 1.18531357e-04, 4.02647854e-04, 4.46964015e-05,
       1.74753700e-05, 1.88540564e-05, 3.26131650e-04, 8.23416892e-05,
       3.43390024e-04, 1.91686590e-04, 6.04850045e-04, 2.75271081e-04,
       1.81471923e-04, 2.17965117e-04, 3.93605820e-04, 2.36572022e-04,
       1.70048983e-04, 3.43554932e-04, 1.49253958e-04, 3.58772052e-04,
       2.63894716e-05, 5.70927849e-05, 0.00000000e+00, 2.89809231e-05,
       6.47967885e-05, 4.11093221e-05, 2.14801414e-04, 2.40949632e-04,
       2.02792758e-05, 2.49642728e-04, 3.21244484e-04, 7.32023721e-06,
       5.66518255e-05, 0.00000000e+00, 2.88416959e-05, 1.56573049e-04,
       0.00000000e+00, 5.56516137e-05, 5.73642329e-04, 4.68682552e-04,
       7.86183739e-05, 4.97846765e-05, 1.23830568e-04, 3.54865972e-04,
       1.50719052e-04, 2.97068505e-05, 1.72614903e-04, 3.03014355e-04,
       2.43250969e-05, 2.25217687e-04, 1.91038080e-04, 5.52706982e-04,
       1.37675586e-04, 8.45958371e-05, 2.23918115e-04, 4.25967593e-04,
       2.27451305e-04, 1.79095533e-04, 3.42328660e-04, 1.17746551e-04,
       3.06990016e-04, 4.28498482e-05, 1.18261930e-04, 0.00000000e+00,
       4.37198954e-05, 6.08905326e-05, 7.51961905e-05, 1.12770761e-04,
       1.60020173e-04, 6.56277583e-05, 2.66325561e-04, 4.01114487e-04,
       2.99110348e-05, 1.85594664e-05, 0.00000000e+00, 1.54300968e-05,
       1.47432561e-04, 0.00000000e+00, 5.21852540e-05, 5.46555011e-04,
       5.10580254e-04, 6.57748306e-05, 1.27636661e-04, 1.01013032e-04,
       5.11901653e-04, 2.31260670e-04, 4.54754153e-06, 2.63174823e-04,
       2.52089795e-04, 6.17674758e-05, 1.81181880e-04, 1.87284511e-04,
       5.97785126e-04, 1.76717586e-04, 3.32998427e-05, 2.00395315e-04,
       4.75653510e-04, 2.43317803e-04, 1.81846200e-04, 3.72520671e-04,
       1.67174479e-04, 2.40488722e-04, 4.35844845e-04, 5.35013151e-04,
       2.40240392e-04, 5.08656837e-04, 5.44452082e-04, 3.09097311e-03,
       1.65267010e-03, 2.94408973e-02, 2.76366640e-02, 1.47814550e-02,
       2.10565143e-02, 2.10397296e-02, 2.31742452e-02, 1.08801911e-02,
       2.32715775e-02, 7.66467239e-03, 8.83440352e-03, 1.58065505e-02,
       1.42774157e-02, 7.64515305e-03, 1.21772694e-02, 2.84605206e-02,
       6.64886758e-03, 1.09290814e-02, 1.75631239e-02, 7.18848712e-03,
       2.01474919e-02, 2.10993687e-02, 0.00000000e+00])
fi

array([8.37719494e-04, 5.95851621e-03, 9.85790946e-03, 6.58311779e-04,
       4.98447384e-03, 5.34551361e-03, 4.41471291e-03, 4.50505970e-03,
       4.56330928e-03, 4.70680305e-03, 4.30524333e-03, 4.91976436e-03,
       3.97949899e-03, 4.37511247e-03, 4.07568193e-03, 5.17924926e-03,
       4.92749131e-03, 4.81285050e-03, 4.85455992e-03, 4.80210650e-03,
       5.44854764e-03, 4.88728381e-03, 4.26637682e-03, 4.53950962e-03,
       4.67869705e-03, 4.68349264e-03, 4.30915245e-03, 4.56332910e-03,
       4.13989792e-03, 4.64412344e-03, 5.55468678e-03, 4.52450085e-03,
       4.29287607e-03, 4.28285976e-03, 4.46584093e-03, 4.32664699e-03,
       4.51653001e-03, 4.21761114e-03, 4.27743202e-03, 4.64259729e-03,
       5.31728809e-03, 4.87377702e-03, 4.75601150e-03, 5.15440112e-03,
       5.53981827e-03, 5.23599622e-03, 5.26360711e-03, 3.91604110e-03,
       4.06798616e-03, 4.08266886e-03, 4.37467342e-03, 4.21931663e-03,
       4.46754553e-03, 4.24214126e-03, 4.55833046e-03, 5.33798445e-03,
      

In [7]:
fi_sorted = np.copy(fi)
fi_sorted.sort()

In [9]:
from mlxtend.classifier import StackingClassifier

In [12]:
clf1 = lgb.LGBMClassifier(scale_pos_weight=8.255102040816327, n_estimators=100, random_state=17)
clf2 = lgb.LGBMClassifier(scale_pos_weight=4, n_estimators=100, random_state=17)
clf3 = lgb.LGBMClassifier(scale_pos_weight=12, n_estimators=100, random_state=17)
clf4 = KNeighborsClassifier(n_neighbors=3)
clf5 = LogisticRegression(random_state=18)
lr = LogisticRegression(random_state=18)
sclf = StackingClassifier(classifiers=[clf1, clf2, clf3, clf4, clf5], 
                          meta_classifier=lr)

In [13]:
%%time
cross_val_score(sclf, X, y, scoring="f1")

  if diff:
  if diff:
  if diff:
  if diff:
  if diff:
  if diff:
  if diff:
  if diff:
  if diff:
  if diff:
  if diff:
  if diff:
  if diff:
  if diff:
  if diff:
  if diff:
  if diff:
  if diff:


CPU times: user 2h 18min 31s, sys: 5.16 s, total: 2h 18min 36s
Wall time: 2h 15min 11s


array([0.19726487, 0.13368324, 0.16232919])

In [14]:
print(f1_score(sclf.predict(X_test[:, fi > fi_sorted[100]]), y_test))

  if diff:
  if diff:
  if diff:


0.5701763123817113


In [16]:
sclf

StackingClassifier(average_probas=False,
          classifiers=[LGBMClassifier(boosting_type='gbdt', class_weight=None, colsample_bytree=1.0,
        learning_rate=0.1, max_depth=-1, min_child_samples=20,
        min_child_weight=0.001, min_split_gain=0.0, n_estimators=100,
        n_jobs=-1, num_leaves=31, objective=None, random_seed=18,
        ra...nalty='l2', random_state=18, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)],
          meta_classifier=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False),
          refit=True, store_train_meta_features=False,
          use_features_in_secondary=False, use_probas=False, verbose=0)

In [18]:
prediction = sclf.predict(test["x"][:, fi > fi_sorted[100]]).astype(int)

  if diff:
  if diff:
  if diff:


In [19]:
np.unique(prediction, return_counts=True)

(array([0, 1]), array([3905,  140]))

In [20]:
df_sub = pd.read_csv("sample_submission.csv")
df_sub["mg"] = prediction
df_sub.to_csv("my_sub.csv", index=False)

In [21]:
df_sub

Unnamed: 0,Id,mg
0,0,0
1,1,0
2,2,0
3,3,0
4,4,0
5,5,0
6,6,0
7,7,0
8,8,0
9,9,0
