In [10]:
from ptic import pmi_tfidf_classifier as ptah
import numpy as np
from nltk.tokenize import word_tokenize
import pandas as pd
import nltk
from time import time
import string
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix

stop_words = set(nltk.corpus.stopwords.words('english') + list(string.punctuation))

def tokenize(s):
    return [i for i in word_tokenize(s.lower()) if i not in stop_words]

def tokenization(train_data, var_name='abstract'):
    tokenized_texts = []
    #print("Tokenization....")
    for _, row in train_data.iterrows():
        text = str(row[var_name])
        text = str(row['title']) + ' ' + str(row['abstract']) + ' ' + str(row['tox_annotation'])
        words = tokenize(text)
        tokenized_texts.append(words)
    return tokenized_texts

In [12]:
path = '../datasets/DILI/merged_additional_data_dili_cleaned.csv'

data_raw = pd.read_csv(path)
indices = np.random.permutation(data_raw.index)
data = data_raw.loc[indices]
data = data_raw.sample(frac=1)

idx = int(data.shape[0] * 0.1)
test_data = data.iloc[:idx]
train_data = data.iloc[idx:]
targets_train = train_data['label'].values
targets_test = test_data['label'].values
train_data

s1 = time()
tokenized_texts = tokenization(train_data)
N = len(tokenized_texts)
word2text_count = ptah.get_word_stat(tokenized_texts)
words_pmis = ptah.create_pmi_dict(tokenized_texts, targets_train, min_count=5)
e1 = time()

s2 = time()
tokenized_test_texts = tokenization(test_data)
results = ptah.classify_pmi_based(words_pmis, word2text_count, tokenized_test_texts, N)
e2 = time()

print('trainin time (min):', (e1 - s1) / 60)
print('testing time (min):', (e2 - s2) / 60)

trainin time (min): 0.5511654257774353
testing time (min): 0.066676394144694


In [13]:
tn, fp, fn, tp = confusion_matrix(results, targets_test).ravel()
print('accuracy:', accuracy_score(results, targets_test))
print('precision:', precision_score(results, targets_test))
print('recall:', recall_score(results, targets_test))
print('f1_score:', f1_score(results, targets_test))
print('fp_rate:', fp / (fp + tn))
print('fn_rate:', fn / (fn + tp))

accuracy: 0.9337436640115858
precision: 0.9650756693830035
recall: 0.8441955193482689
f1_score: 0.9005975013579578
fp_rate: 0.016853932584269662
fn_rate: 0.15580448065173116
