In [113]:
import polars as pl
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc


In [92]:
# Embeddings file analysed by this notebook; the prefix of its base name
# (text before the first underscore) labels the run, e.g. 'basic'.
file = '../results/basic_embeddings.csv'
# Drop the directory part first so an underscore in a folder name cannot
# end up in the label.
mode = file.split('/')[-1].split('_')[0]

In [47]:
# Load the embeddings table and derive two label columns from `affect`:
#   status  - readable label ('Positive' / 'Negative' / 'Non-Positive')
#   outcome - binary target: 1 for affect == 1, 0 for affect == -1, null otherwise
data = (pl.read_csv(file)  # same path as before, taken from the config cell
        .with_columns(
            pl.when(pl.col('affect') == 1).then(pl.lit('Positive'))
                # Negatives are coded -1 in `affect` (the same coding `outcome`
                # uses below); 0 is still accepted so earlier inputs keep working.
                .when((pl.col('affect') == 0) | (pl.col('affect') == -1)).then(pl.lit('Negative'))
                .when(pl.col('affect').is_null()).then(pl.lit('Non-Positive'))
                .alias('status'),
            pl.when(pl.col('affect') == 1).then(pl.lit(1))
                .when(pl.col('affect') == -1).then(pl.lit(0))
                .alias('outcome')
        )
        )
# Quick look at the id columns, the raw label and the first embedding dimension.
print(data[['', 'condition_name', 'drug_name', 'affect', '0']].head())

shape: (5, 5)
┌─────┬─────────────────────┬────────────────┬────────┬───────────┐
│     ┆ condition_name      ┆ drug_name      ┆ affect ┆ 0         │
│ --- ┆ ---                 ┆ ---            ┆ ---    ┆ ---       │
│ i64 ┆ str                 ┆ str            ┆ f64    ┆ f64       │
╞═════╪═════════════════════╪════════════════╪════════╪═══════════╡
│ 0   ┆ acute kidney injury ┆ morphine       ┆ -1.0   ┆ -0.329442 │
│ 1   ┆ acute kidney injury ┆ delafloxacin   ┆ -1.0   ┆ -0.631743 │
│ 2   ┆ acute kidney injury ┆ aspirin        ┆ -1.0   ┆ -0.657761 │
│ 3   ┆ acute kidney injury ┆ levoleucovorin ┆ -1.0   ┆ -0.11017  │
│ 4   ┆ acute kidney injury ┆ afatinib       ┆ -1.0   ┆ -0.602554 │
└─────┴─────────────────────┴────────────────┴────────┴───────────┘


In [48]:
# Row count for every (condition_name, outcome) pair.
data.group_by('condition_name', 'outcome').len()

condition_name,outcome,len
str,i32,u32
"""acute kidney i…",1,10
"""acute liver in…",0,10
"""acute kidney i…",0,10
"""acute myocardi…",0,10
"""acute myocardi…",1,10
"""acute liver in…",1,10
"""gi bleed""",1,10
"""gi bleed""",0,10


In [23]:
# Every column that is not an identifier or label column is used as a feature.
non_feature_cols = {'', 'affect', 'condition_name', 'drug_name', 'outcome', 'status'}
embed_col_names = [col for col in data.columns if col not in non_feature_cols]

In [89]:
def performance(X_test, y_test, logreg, predictions):
    """Score a fitted binary classifier on one evaluation set.

    Parameters
    ----------
    X_test : feature matrix handed to ``logreg.predict_proba``.
    y_test : true labels for ``X_test``.
    logreg : fitted classifier exposing ``predict_proba``.
    predictions : hard class predictions for ``X_test``.

    Returns
    -------
    list
        ``[accuracy, precision, recall, f1, roc_auc]``.
    """
    # Threshold-based metrics computed from the hard predictions.
    label_scorers = (accuracy_score, precision_score, recall_score, f1_score)
    scores = [scorer(y_test, predictions) for scorer in label_scorers]

    # ROC AUC is computed from the second column of predict_proba.
    positive_scores = logreg.predict_proba(X_test)[:, 1]
    false_pos_rate, true_pos_rate, _ = roc_curve(y_test, positive_scores)
    scores.append(auc(false_pos_rate, true_pos_rate))
    return scores

In [114]:
def make_classifier(model):
    """Return a new, unfitted classifier for a short model name.

    Parameters
    ----------
    model : str
        One of ``'logistic'``, ``'rf'`` or ``'knn'``.

    Returns
    -------
    LogisticRegression, RandomForestClassifier or KNeighborsClassifier
        Built with the fixed settings used throughout this notebook.

    Raises
    ------
    ValueError
        If ``model`` is not a supported name. Previously this fell through
        every branch and failed with ``UnboundLocalError`` at the return.
    """
    if model == 'logistic':
        return LogisticRegression(random_state=0)
    if model == 'rf':
        return RandomForestClassifier(random_state=0)
    if model == 'knn':
        return KNeighborsClassifier(n_neighbors=3)
    raise ValueError(
        f"unknown model {model!r}; expected one of 'logistic', 'rf', 'knn'"
    )

In [115]:
# Leave-one-condition-out evaluation. For each model and each condition:
#   1. fit on a 75/25 train/test split of the rows from all other conditions,
#   2. score the test split,
#   3. score the rows of the left-out condition with the same fitted model.
results = []
# unique(maintain_order=True) yields conditions in first-appearance order, so
# the result rows come out in the same order on every run; iterating a set of
# strings gives no such guarantee. Computed once, as it does not depend on the model.
conditions = data['condition_name'].unique(maintain_order=True).to_list()
for model in ['logistic', 'rf', 'knn']:
    for condition in conditions:
        hold_out = data.filter(pl.col('condition_name') == condition)

        # train on other data - train test split
        subset_data = data.filter(pl.col('condition_name') != condition)
        y_actual = subset_data['outcome'].to_numpy()
        X = subset_data[embed_col_names].to_numpy()
        X_train, X_test, y_train, y_test = train_test_split(X, y_actual, test_size=0.25, random_state=100)
        clf = make_classifier(model)

        clf.fit(X_train, y_train)
        predictions = clf.predict(X_test)

        accuracy, precision, recall, f1, roc_auc = performance(X_test, y_test, clf, predictions)

        ## validation on hold out set
        # X, y_actual and predictions are rebound to the held-out condition here.
        y_actual = hold_out['outcome'].to_numpy()
        X = hold_out[embed_col_names].to_numpy()
        predictions = clf.predict(X)

        hold_accuracy, hold_precision, hold_recall, hold_f1, hold_roc_auc = performance(X, y_actual, clf, predictions)

        results.append([condition, model, accuracy, precision, recall, f1, roc_auc,
                        hold_accuracy, hold_precision, hold_recall, hold_f1, hold_roc_auc])

Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md



In [116]:
# One row per (model, held-out condition), tagged with the embedding mode.
# orient='row' states that each inner list of `results` is a record, so polars
# does not have to infer the orientation.
(pl.DataFrame(results,
              schema=['condition', 'model', 'accuracy', 'precision', 'recall', 'f1', 'roc_auc',
                      'hold_accuracy', 'hold_precision', 'hold_recall', 'hold_f1', 'hold_roc_auc'],
              orient='row')
        .with_columns(
            pl.lit(mode).alias('mode')
            )
                               )

condition,model,accuracy,precision,recall,f1,roc_auc,hold_accuracy,hold_precision,hold_recall,hold_f1,hold_roc_auc,mode
str,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,str
"""acute myocardi…","""logistic""",0.733333,0.714286,0.714286,0.714286,0.892857,0.8,0.75,0.9,0.818182,0.62,"""basic"""
"""gi bleed""","""logistic""",0.6,0.571429,0.571429,0.571429,0.571429,1.0,1.0,1.0,1.0,0.74,"""basic"""
"""acute liver in…","""logistic""",0.733333,0.666667,0.857143,0.75,0.714286,0.85,0.818182,0.9,0.857143,0.58,"""basic"""
"""acute kidney i…","""logistic""",0.533333,0.5,0.714286,0.588235,0.625,0.6,0.5625,0.9,0.692308,0.93,"""basic"""
"""acute myocardi…","""rf""",0.6,0.571429,0.571429,0.571429,0.6875,0.8,0.75,0.9,0.818182,0.595,"""basic"""
"""gi bleed""","""rf""",0.6,0.571429,0.571429,0.571429,0.589286,1.0,1.0,1.0,1.0,0.665,"""basic"""
"""acute liver in…","""rf""",0.666667,0.625,0.714286,0.666667,0.794643,0.85,0.818182,0.9,0.857143,0.475,"""basic"""
"""acute kidney i…","""rf""",0.533333,0.5,0.571429,0.533333,0.642857,0.6,0.5625,0.9,0.692308,0.98,"""basic"""
"""acute myocardi…","""knn""",0.666667,0.583333,1.0,0.736842,0.732143,0.8,0.75,0.9,0.818182,0.59,"""basic"""
"""gi bleed""","""knn""",0.666667,0.583333,1.0,0.736842,0.732143,1.0,1.0,1.0,1.0,0.605,"""basic"""
