In [2]:
from itertools import chain
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import LabelBinarizer
import sklearn
import pycrfsuite
from jinja2 import Template

In [3]:
with open('sample_dataset', 'r') as reader:
    text = reader.read().splitlines()

sent = []
train_set = []
for word_labels in text:
    line = tuple(word_labels.split())
    #print(line)    
    if line == ():
        train_set.append(sent)
        sent = []
    else:
        sent.append(line)


In [4]:
print(train_set[0])

[('Which', 'O'), ('patients', 'O'), ('are', 'O'), ('registered', 'O'), ('in', 'O'), ('R3500-AD-1906', 'STUDYID'), ('?', 'O')]


In [7]:
def word2features(sent, i):
    word = sent[i][0]
    features = [
        'bias',
        'word.lower=' + word.lower(),
        'word[-3:]=' + word[-3:],
        'word[-2:]=' + word[-2:],
        'word.isupper=%s' % word.isupper(),
        'word.istitle=%s' % word.istitle(),
        'word.isdigit=%s' % word.isdigit(),
        ]
    if i > 0:
        word1 = sent[i-1][0]
        features.extend([
            '-1:word.lower=' + word1.lower(),
            '-1:word.istitle=%s' % word1.istitle(),
            '-1:word.isupper=%s' % word1.isupper(),
        ])
    else:
        features.append('BOS')
        
    if i < len(sent)-1:
        word1 = sent[i+1][0]
        features.extend([
            '+1:word.lower=' + word1.lower(),
            '+1:word.istitle=%s' % word1.istitle(),
            '+1:word.isupper=%s' % word1.isupper(),
        ])
    else:
        features.append('EOS')
                
    return features


def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    return [label for token, label in sent]

def sent2tokens(sent):
    return [token for token, label in sent]

In [None]:
sent2features(train_set[0])[1]

In [None]:
%%time
X_train = [sent2features(s) for s in train_set]
y_train = [sent2labels(s) for s in train_set]

In [None]:
_ = 0
print(X_train[_])
print(y_train[_])

In [None]:
%%time
trainer = pycrfsuite.Trainer(verbose=False)

for xseq, yseq in zip(X_train, y_train):
    trainer.append(xseq, yseq)

In [None]:
trainer.set_params({
    'c1': 1.0,   # coefficient for L1 penalty
    'c2': 1e-3,  # coefficient for L2 penalty
    'max_iterations': 500,  # stop earlier

    # include transitions that are possible, but not observed
    'feature.possible_transitions': True
})

In [None]:
trainer.params()

In [None]:
%%time
trainer.train('test.crfsuite')

In [None]:
trainer.logparser.last_iteration

In [5]:
tagger = pycrfsuite.Tagger()
tagger.open('test.crfsuite')

<contextlib.closing at 0x1b3cc5d8c70>

In [9]:
example_sent = train_set[1]
print(example_sent)
print(' '.join(sent2tokens(example_sent)), end='\n\n')
print("Predicted:", ' '.join(tagger.tag(sent2features(example_sent))))
print("Correct:  ", ' '.join(sent2labels(example_sent)))

[('What', 'O'), ('patients', 'O'), ('in', 'O'), ('R3500-AD-1906', 'STUDYID'), ('are', 'O'), ('from', 'O'), ('Germany', 'COUNTRY'), ('?', 'O')]
What patients in R3500-AD-1906 are from Germany ?

Predicted: O O O STUDYID O O COUNTRY O
Correct:   O O O STUDYID O O COUNTRY O


In [10]:
sql_template = Template("SELECT * from {{ table_name }} where")
sql_query = sql_template.render(table_name = "study_table")
print(type(sql_query))

<class 'str'>


In [11]:
predict_labels = tagger.tag(sent2features(example_sent))
print(predict_labels)
print(example_sent)

['O', 'O', 'O', 'STUDYID', 'O', 'O', 'COUNTRY', 'O']
[('What', 'O'), ('patients', 'O'), ('in', 'O'), ('R3500-AD-1906', 'STUDYID'), ('are', 'O'), ('from', 'O'), ('Germany', 'COUNTRY'), ('?', 'O')]


In [12]:
conditions = []
for word, label in zip(example_sent, predict_labels):
    #print(word[0], label)
    if label != str("O"):
        conditions.append(str(label) + '=' + str('"') + str(word[0]) + str('"'))

In [13]:
conditions

['STUDYID="R3500-AD-1906"', 'COUNTRY="Germany"']

In [14]:
final_sql_condition = " AND ".join(conditions)
print(final_sql_condition)

STUDYID="R3500-AD-1906" AND COUNTRY="Germany"


In [15]:
sql_query = sql_query + " " + final_sql_condition
print(sql_query)

SELECT * from study_table where STUDYID="R3500-AD-1906" AND COUNTRY="Germany"
