In [2]:
import numpy as np
import sys
import polars as pl
import pandas as pd
import random
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from imblearn.under_sampling import NearMiss
from sklearn import metrics
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc, confusion_matrix, precision_recall_curve
from sklearn import preprocessing
from sklearn.model_selection import cross_val_score

In [20]:
data = (pl.read_csv('../results/results/embeddings/rag_ember-v1_gpt4_embeddings.csv', # has_header=False,
                    new_columns=['condition_name', 'drug_name', 'affect'])
        .filter(pl.col('affect') != 0)
        .with_columns(
            pl.when(pl.col('affect') == 1).then(1).otherwise(0).alias("affect")
            )
        )
print(data.shape)
data.head()

(1476, 2051)


condition_name,drug_name,affect,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,…,987_duplicated_0,988_duplicated_0,989_duplicated_0,990_duplicated_0,991_duplicated_0,992_duplicated_0,993_duplicated_0,994_duplicated_0,995_duplicated_0,996_duplicated_0,997_duplicated_0,998_duplicated_0,999_duplicated_0,1000_duplicated_0,1001_duplicated_0,1002_duplicated_0,1003_duplicated_0,1004_duplicated_0,1005_duplicated_0,1006_duplicated_0,1007_duplicated_0,1008_duplicated_0,1009_duplicated_0,1010_duplicated_0,1011_duplicated_0,1012_duplicated_0,1013_duplicated_0,1014_duplicated_0,1015_duplicated_0,1016_duplicated_0,1017_duplicated_0,1018_duplicated_0,1019_duplicated_0,1020_duplicated_0,1021_duplicated_0,1022_duplicated_0,1023_duplicated_0
str,str,i32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""acute kidney i…","""6-aminocaproic…",0,-0.113791,0.5761229,-0.057424,0.422981,-1.039098,-0.316022,0.475713,0.384223,0.5568268,0.8824807,-0.369988,-0.578911,-0.141257,0.029517,-0.631581,-0.261116,0.013399,-0.50964,-0.322155,-0.16976,0.739859,-0.040078,0.254749,0.067516,0.2803745,1.0264088,0.255942,-0.605479,1.5246605,0.5719897,-0.041438,0.3692937,0.716658,0.339189,…,1.1658257,0.059245,0.416834,-0.527663,0.247297,0.5109924,0.199492,-0.105386,0.3376889,0.419384,-0.238846,0.6458211,-0.75336,-0.320593,0.215728,0.109278,0.3314094,0.594011,-0.255532,0.249085,0.3002335,0.042368,-0.217641,0.045214,0.365678,-0.577791,0.7765181,0.338307,-0.757739,-0.324013,-0.650646,0.5834112,0.9333886,-0.554994,-0.122884,-0.441791,-0.550147
"""acute kidney i…","""3-iodobenzylgu…",0,0.167238,0.075416,0.351527,0.191419,-0.820671,0.10259,0.172632,0.143705,-0.002329,0.561895,0.593211,-0.399002,0.601526,-0.297195,0.0374561,0.544703,-0.043588,-0.268589,-0.042387,0.312905,0.830163,0.6969075,-0.537826,0.5075295,-0.235583,0.9864152,-0.089572,-0.11895,1.3799655,0.4286258,0.193598,-0.245762,1.162368,-0.254496,…,0.416473,-0.104335,0.16713,-0.505979,0.237364,0.7699128,0.140969,0.6449821,0.325432,-0.076962,-0.389527,0.4648688,-0.833855,-0.344343,-0.120699,0.22262,0.328187,0.461831,0.160315,0.4900108,-0.339558,0.219848,0.276337,0.008242,0.7823667,-0.884464,0.487854,0.1376159,-0.255247,-0.350863,-0.000354,0.6476129,1.1781952,-0.479792,0.089264,-0.3088,-0.232126
"""acute kidney i…","""abacavir""",0,-0.950485,0.584202,-0.062872,0.225852,-0.368249,0.009063,-0.46825,-0.410719,0.203988,0.6683506,-0.391967,-0.558072,0.473859,0.03328,-0.638843,0.120856,0.133404,-0.539722,-0.148631,0.233409,0.7949639,0.7236491,-0.629767,-0.268083,0.158468,0.8079171,-0.140388,-0.691627,1.2390614,0.243923,-0.01037,-0.41288,0.796795,-0.361516,…,1.0070913,0.3670554,0.996402,-0.82922,0.165006,0.8591236,0.737249,-0.549624,0.389619,0.2588,-0.188909,0.154625,-0.698787,0.025228,-0.24539,0.001009,0.4528022,0.136914,0.01325,0.5839833,0.2742374,0.162868,0.165932,0.039104,0.609531,-0.882788,0.4865759,-0.054819,-0.29218,-0.347759,-0.542438,0.36695,1.1608756,-0.724895,0.161244,0.193736,-0.699643
"""acute kidney i…","""abatacept""",0,-0.644795,0.6654043,0.018576,0.2889785,-0.504876,0.31782,-0.222896,-0.263948,0.168481,0.8802953,-0.495485,-0.44851,0.189278,0.21243,-0.803692,-0.070876,-0.214516,-0.172695,-0.198993,0.274046,0.74019,0.6826604,-0.260914,-0.370646,0.1252962,0.6904323,-0.148557,-0.643003,1.4017996,0.586186,0.006193,-0.145872,0.693184,0.151168,…,1.4076331,0.6122786,0.432108,-0.451074,0.210564,0.47348,0.457086,0.105763,0.504564,0.44329,0.260639,0.551414,-0.624523,-0.286071,-0.077608,0.14784,0.37661,0.4917209,0.2599635,0.6102624,-0.012514,0.223976,-0.108235,0.170128,0.7445779,-0.971231,0.6271078,0.109472,-0.663006,-0.335138,-0.487633,0.2993779,0.771351,-0.35752,0.140006,-0.479819,-0.665837
"""acute kidney i…","""abemaciclib""",0,-0.321315,0.095288,0.031969,0.285429,-0.090508,0.7392465,-0.014736,-0.58783,0.496446,1.13638,-0.227364,-0.665901,0.211377,-1.049659,-0.298218,-0.298889,0.109319,-0.482085,-0.129278,0.3819297,1.1451916,0.430311,0.3275,0.217305,-0.700397,1.0643435,0.263392,-0.891269,0.803679,1.1250405,-0.313871,0.02912,0.878989,-0.165224,…,0.407151,-0.164216,0.8426841,-0.565621,-0.030819,0.663494,0.863509,0.121952,0.3325903,0.98039,-0.177101,-0.161349,0.2599125,0.138193,-0.172958,0.052935,0.2937497,0.239732,-0.656719,0.641902,0.109112,0.195424,-0.203523,0.5096087,0.293409,-0.965997,0.195351,0.248844,-0.513589,0.472769,-0.182507,0.3904834,0.4845075,-0.311689,-0.368671,0.032888,-0.406301


In [21]:
data['condition_name'].value_counts().sort(by = 'condition_name')

condition_name,count
str,u32
"""acute kidney i…",425
"""acute liver in…",512
"""acute myocardi…",313
"""gi bleed""",226


In [22]:
embed_col_names = [x for x in data.columns if x not in ['condition_name', 'drug_name', 'affect', 'outcome', 'status', '']]
support_cols = [x for x in embed_col_names if 'duplicated' not in x]
refute_cols = [x for x in embed_col_names if 'duplicated' in x]

In [23]:
data['affect'].value_counts()

affect,count
i32,u32
1,139
0,1337


In [24]:
condition = 'acute kidney injury'

In [26]:
def make_classifier(model):
    if model == 'logistic':
        clf = LogisticRegression(random_state=0)
    if model == 'rf':
        clf = RandomForestClassifier(random_state=0)
    if model == 'knn':
        clf = KNeighborsClassifier(n_neighbors=5, random_state=0)
    return clf

## Cross Validation

In [None]:
def get_cross_vals(X,y, ):


In [25]:
print(cross_val_score(clf, X_actual, y_actual, cv=5, scoring = 'roc_auc'))
print(cross_val_score(clf, X_actual, y_actual, cv=5, scoring = 'average_precision'))

NameError: name 'clf' is not defined

## Undersampling

In [85]:
random.seed(100)
y_actual = data['affect'].to_numpy()
X_actual = data[embed_col_names].to_numpy()
undersample = NearMiss(version=1, n_neighbors=5)
X,y = undersample.fit_resample(X_actual, y_actual)

In [86]:
# train on other data - train test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=100, stratify=y)

In [87]:
clf = LogisticRegression(random_state=0)
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)
predict_probs = clf.predict_proba(X_test)

In [89]:
precision, recall, thresholds = precision_recall_curve(y_test, predict_probs[:,1])
f1_scores = 2*recall*precision/(recall+precision)
print('Best threshold: ', thresholds[np.argmax(f1_scores)])
print('Best F1-Score: ', np.max(f1_scores))

Best threshold:  0.8467536368462844
Best F1-Score:  0.9117647058823528


In [90]:
print(confusion_matrix(y_test, predictions))
print(classification_report(y_test, predictions))

[[31  4]
 [ 3 31]]
              precision    recall  f1-score   support

           0       0.91      0.89      0.90        35
           1       0.89      0.91      0.90        34

    accuracy                           0.90        69
   macro avg       0.90      0.90      0.90        69
weighted avg       0.90      0.90      0.90        69



In [97]:
print(cross_val_score(clf, X_actual, y_actual, cv=5, scoring = 'roc_auc'))
print(cross_val_score(clf, X_actual, y_actual, cv=5, scoring = 'average_precision'))

[0.88736871 0.82324489 0.76457999 0.79454254 0.88242376]
[0.53185648 0.3257435  0.30931476 0.42108849 0.56033103]


In [63]:
predictions = clf.predict(X_actual)
print(confusion_matrix(y_actual, predictions))
print(f1_score(y_actual, predictions))

[[537 800]
 [  3 135]]
0.2516309412861137


In [91]:

X_train, X_test, y_train, y_test = train_test_split(X_actual, y_actual, test_size=0.3, random_state=100, stratify=y_actual)

In [92]:
clf = LogisticRegression(random_state=0,    max_iter=1000) #, class_weight='balanced')
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)
predict_proba = clf.predict_proba(X_test)
print(confusion_matrix(y_test, predictions))
print(f1_score(y_test, predictions))
print(accuracy_score(y_test, predictions))

[[385  17]
 [ 24  17]]
0.4533333333333333
0.90744920993228


In [72]:
precision, recall, thresholds = precision_recall_curve(y_test, predict_proba[:,1])
f1_scores = 2*recall*precision/(recall+precision)
print('Best threshold: ', thresholds[np.argmax(f1_scores)])
print('Best F1-Score: ', np.max(f1_scores))

Best threshold:  0.4159840383391849
Best F1-Score:  0.5227272727272727
