## Domain:
* Natural language processing (NLP)

## Data source:
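* Sentences from PubMed abstracts (n = 10,000 sentences; 2,000 per section label)
* From the PubMed 200k RCT dataset paper, which also introduces the smaller PubMed 20k RCT dataset used here: https://arxiv.org/pdf/1710.06071.pdf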
* https://github.com/Franck-Dernoncourt/pubmed-rct


## Prediction task:
* Given a random sentence from a PubMed abstract, predict which abstract section the sentence came from:
    * Background
    * Objective
    * Methods
    * Results
    * Conclusions
* Multiclass classification (5 classes)

## Model:
* Recurrent neural network (RNN) with gated recurrent units (GRU)
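For reference, the standard GRU recurrence (Cho et al., 2014) updates a hidden state $h_t$ from the embedded token $x_t$ through an update gate $z_t$ and a reset gate $r_t$. Here $\sigma$ is a sigmoid-type gate activation and $\odot$ is element-wise multiplication. Keras's `GRU` layer follows this general form, but details such as bias placement (the `reset_after` option) and the default gate activation differ between versions. In this notebook, the final state $h_T$ after the last non-padding token feeds the dropout and softmax layers.

$$
\begin{aligned}
z_t &= \sigma\left(W_z x_t + U_z h_{t-1} + b_z\right) \\
r_t &= \sigma\left(W_r x_t + U_r h_{t-1} + b_r\right) \\
\tilde{h}_t &= \tanh\left(W_h x_t + U_h \left(r_t \odot h_{t-1}\right) + b_h\right) \\
h_t &= z_t \odot h_{t-1} + \left(1 - z_t\right) \odot \tilde{h}_t
\end{aligned}
$$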

In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import sklearn.utils
from sklearn.metrics import classification_report, accuracy_score

from keras.models import Sequential
from keras.layers import GRU, Dense, Dropout, Embedding
from keras.preprocessing.sequence import pad_sequences
from keras.callbacks import EarlyStopping, ModelCheckpoint

Using TensorFlow backend.


## 1. Load data

Source: https://raw.githubusercontent.com/Franck-Dernoncourt/pubmed-rct/master/PubMed_20k_RCT/train.txt

#### Examine raw text file

In [2]:
with open('data/pubmed.txt', 'r') as f:
    for line in f.readlines()[:5]:
        print(line)

###24845963

BACKGROUND	This study analyzed liver function abnormalities in heart failure patients admitted with severe acute decompensated heart failure ( ADHF ) .

RESULTS	A post hoc analysis was conducted with the use of data from the Evaluation Study of Congestive Heart Failure and Pulmonary Artery Catheterization Effectiveness ( ESCAPE ) .

RESULTS	Liver function tests ( LFTs ) were measured at 7 time points from baseline , at discharge , and up to 6 months follow-up .

RESULTS	Survival analyses were used to assess the association between admission Model of End-Stage Liver Disease Excluding International Normalized Ratio ( MELD-XI ) scores and patient outcome.There was a high prevalence of abnormal baseline ( admission ) LFTs ( albumin 23.8 % , aspartate transaminase 23.5 % , alanine transaminase 23.8 % , and total bilirubin 36.1 % ) .
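The raw file interleaves `###` abstract-ID lines with tab-separated `LABEL<TAB>sentence` lines and blank separator lines. Instead of loading it with `pd.read_csv` and then filtering out the ID rows, you can parse it directly and attach the abstract ID to every sentence. The minimal sketch below uses a hypothetical helper, `parse_pubmed_rct`, and assumes the file follows exactly the format shown above. Keeping the ID would also make it possible to split the data by abstract rather than by sentence.

```python
def parse_pubmed_rct(lines):
    """Parse PubMed RCT-style lines into (abstract_id, label, sentence) tuples.

    Assumed format (as in the raw output above): '###<id>' starts a new abstract,
    each sentence line is '<LABEL>\t<sentence>', and blank lines separate abstracts.
    """
    records = []
    abstract_id = None
    for line in lines:
        line = line.rstrip('\n')
        if not line.strip():
            continue                      # blank separator between abstracts
        if line.startswith('###'):
            abstract_id = line[3:]        # keep the ID, e.g. '24845963'
            continue
        label, sentence = line.split('\t', 1)
        records.append((abstract_id, label, sentence))
    return records


# Toy input abbreviated from the raw lines shown above (not the full file)
sample_lines = [
    '###24845963\n',
    'BACKGROUND\tThis study analyzed liver function abnormalities in heart failure patients .\n',
    'RESULTS\tA post hoc analysis was conducted .\n',
    '\n',
]
for record in parse_pubmed_rct(sample_lines):
    print(record)
```

With a real file, `with open('data/pubmed.txt') as f: records = parse_pubmed_rct(f)` works the same way, because iterating over a file yields its lines.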



#### Load into dataframe

In [3]:
df = pd.read_csv('data/pubmed.txt',
                 sep='\t',
                 header=None,
                 skiprows=1,
                 names=['LABEL', 'TEXT'])
df[:10]

|   | LABEL | TEXT |
|---|-------|------|
| 0 | BACKGROUND | This study analyzed liver function abnormaliti... |
| 1 | RESULTS | A post hoc analysis was conducted with the use... |
| 2 | RESULTS | Liver function tests ( LFTs ) were measured at... |
| 3 | RESULTS | Survival analyses were used to assess the asso... |
| 4 | RESULTS | The percentage of patients with abnormal LFTs ... |
| 5 | RESULTS | When mean hemodynamic profiles were compared i... |
| 6 | RESULTS | Multivariable analyses revealed that patients ... |
| 7 | CONCLUSIONS | Abnormal LFTs are common in the ADHF populatio... |
| 8 | CONCLUSIONS | Elevated MELD-XI scores are associated with po... |
| 9 | ###24469619 | NaN |


#### Remove metadata rows (abstract IDs such as '###24845963')

In [4]:
df = df[df['LABEL'].isin(['BACKGROUND', 'OBJECTIVE', 'METHODS', 'RESULTS', 'CONCLUSIONS'])].reset_index(drop=True)
df[:5]

|   | LABEL | TEXT |
|---|-------|------|
| 0 | BACKGROUND | This study analyzed liver function abnormaliti... |
| 1 | RESULTS | A post hoc analysis was conducted with the use... |
| 2 | RESULTS | Liver function tests ( LFTs ) were measured at... |
| 3 | RESULTS | Survival analyses were used to assess the asso... |
| 4 | RESULTS | The percentage of patients with abnormal LFTs ... |


#### Examine distribution of labels

In [5]:
df['LABEL'].value_counts()

METHODS        9897
RESULTS        9713
CONCLUSIONS    4571
BACKGROUND     3621
OBJECTIVE      2333
Name: LABEL, dtype: int64

#### Build example balanced dataset (2,000 sentences per label)

In [6]:
df = pd.concat([df[df['LABEL'] == label].sample(2000) for label in ['BACKGROUND', 'OBJECTIVE', 'METHODS', 'RESULTS', 'CONCLUSIONS']])
df[:5]

|       | LABEL | TEXT |
|-------|-------|------|
| 5626  | BACKGROUND | In recent years , the surgical step-up approac... |
| 17991 | BACKGROUND | Peripheral opioid receptor targeting has been ... |
| 3338  | BACKGROUND | A constructive safety culture is essential for... |
| 13095 | BACKGROUND | NCT00679809 . |
| 19278 | BACKGROUND | We aimed to compare gefitinib with placebo in ... |


#### Shuffle rows

In [7]:
df = df.sample(frac=1).reset_index(drop=True)
df[:5]

|   | LABEL | TEXT |
|---|-------|------|
| 0 | BACKGROUND | One hypothesis suggests that the differential ... |
| 1 | OBJECTIVE | To investigate whether learning basic life sup... |
| 2 | RESULTS | All patients received six doses of study medic... |
| 3 | RESULTS | No significant differences were found between ... |
| 4 | OBJECTIVE | The authors studied the immediate and long-ter... |
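The two cells above call `sample` without a `random_state`, so each run draws a different balanced subset in a different order. The minimal sketch below does the same per-label sampling and shuffling reproducibly. It assumes pandas 1.1 or later, which added `DataFrameGroupBy.sample`. The helper name and the toy DataFrame are hypothetical.

```python
import pandas as pd


def balanced_shuffled_sample(df, label_col, n_per_label, seed=0):
    """Draw n_per_label rows per label, then shuffle, using fixed seeds."""
    sampled = df.groupby(label_col).sample(n=n_per_label, random_state=seed)  # pandas >= 1.1
    return sampled.sample(frac=1, random_state=seed).reset_index(drop=True)


# Toy data standing in for the real DataFrame (hypothetical sentences)
toy = pd.DataFrame({
    'LABEL': ['METHODS'] * 5 + ['RESULTS'] * 4 + ['OBJECTIVE'] * 3,
    'TEXT': ['sentence %d' % i for i in range(12)],
})
balanced = balanced_shuffled_sample(toy, 'LABEL', n_per_label=3, seed=42)
print(balanced['LABEL'].value_counts())
```

Running it twice with the same `seed` returns an identical DataFrame, which makes the downstream splits and scores easier to compare across runs.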


#### Extract text (documents) and corresponding labels

In [8]:
documents = list(df['TEXT'].values)
documents[:5]

["One hypothesis suggests that the differential response to ondansetron - and serotonin-specific re-uptake inhibitors ( SSRIs ) may be due to a functional polymorphism of the 5 ' - HTTLPR promoter region in SLC6A4 , the gene that codes for the serotonin transporter ( 5-HTT ) .",
 'To investigate whether learning basic life support ( BLS ) and cardiopulmonary resuscitation ( CPR ) from video produce higher learning outcomes compared to pictures in reciprocal learning .',
 'All patients received six doses of study medication .',
 'No significant differences were found between the two study groups .',
 'The authors studied the immediate and long-term performance and complications of two twin-catheter systems , the Tesio catheter ( TC ) and the LifeCath Twin ( LC ) , to inform clinical practice .']

In [9]:
len(documents)

10000

In [10]:
labels = list(df['LABEL'].values)
labels[:5]

['BACKGROUND', 'OBJECTIVE', 'RESULTS', 'RESULTS', 'OBJECTIVE']

#### Convert string labels to integer indices

In [11]:
UNIQUE_LABELS = ['BACKGROUND', 'OBJECTIVE', 'METHODS', 'RESULTS', 'CONCLUSIONS']
y = [UNIQUE_LABELS.index(label) for label in labels]
y[:5]

[0, 1, 3, 3, 1]
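`list.index` scans the label list once per label, which is fine for five classes. A dictionary makes the mapping explicit and also gives the inverse lookup needed later, when predicted indices (the `np.argmax` output in Section 6) are turned back into section names. Here is a small sketch using the labels shown above; `label2idx` is a name introduced here, not part of the notebook.

```python
UNIQUE_LABELS = ['BACKGROUND', 'OBJECTIVE', 'METHODS', 'RESULTS', 'CONCLUSIONS']

# Dictionary lookup instead of list.index(): constant time per label
label2idx = {label: idx for idx, label in enumerate(UNIQUE_LABELS)}

labels = ['BACKGROUND', 'OBJECTIVE', 'RESULTS', 'RESULTS', 'OBJECTIVE']
y = [label2idx[label] for label in labels]
print(y)  # [0, 1, 3, 3, 1]

# Inverse mapping: turn integer predictions back into section names
decoded = [UNIQUE_LABELS[idx] for idx in y]
print(decoded)
```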

## 2. Preprocess data

#### Lowercase all documents

In [12]:
documents = [document.lower() for document in documents]
documents[:5]

["one hypothesis suggests that the differential response to ondansetron - and serotonin-specific re-uptake inhibitors ( ssris ) may be due to a functional polymorphism of the 5 ' - httlpr promoter region in slc6a4 , the gene that codes for the serotonin transporter ( 5-htt ) .",
 'to investigate whether learning basic life support ( bls ) and cardiopulmonary resuscitation ( cpr ) from video produce higher learning outcomes compared to pictures in reciprocal learning .',
 'all patients received six doses of study medication .',
 'no significant differences were found between the two study groups .',
 'the authors studied the immediate and long-term performance and complications of two twin-catheter systems , the tesio catheter ( tc ) and the lifecath twin ( lc ) , to inform clinical practice .']

#### Tokenize each document (naive whitespace tokenization)

In [13]:
tokenized_documents = [document.split() for document in documents]
tokenized_documents[0]

['one', 'hypothesis', 'suggests', 'that', 'the', 'differential', 'response', 'to',
 'ondansetron', '-', 'and', 'serotonin-specific', 're-uptake', 'inhibitors', '(', 'ssris', ')',
 'may', 'be', 'due', 'to', 'a', 'functional', 'polymorphism', 'of', 'the', '5', "'", '-',
 'httlpr', 'promoter', 'region', 'in', 'slc6a4', ',', 'the', 'gene', 'that', 'codes', 'for',
 'the', 'serotonin', 'transporter', '(', '5-htt', ')', '.']

#### Other optional preprocessing steps (not used in this example):
* Preserving original capitalization

* Removing punctuation

* Splitting hyphenated words
    * "health-based" --> "health", "based"

* Replacing numbers with special tokens
    * "1.4 mcg" --> "`<NUM>` mcg"

* Replacing URLs with special tokens

* Removing stop words
    * "the", "a", "to", "of", etc.

* Stemming (e.g., with the Porter stemmer)
    * "working" --> "work"
    * "studies" --> "studi"
    * "studying" --> "studi"

* Lemmatizing
    * "am", "are", "is" --> "be"
    * "studies" --> "study"
    * "studying" --> "study"
    * "better" --> "good"

* More sophisticated tokenization (e.g., separating punctuation attached to words, or keeping chemical names and gene symbols intact)
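Several of the steps listed above need only the standard library. The function below is a hypothetical illustration and is not part of this notebook's pipeline. Its stop-word list is a tiny made-up sample, and its number and URL patterns are deliberately simple assumptions. Stemming and lemmatizing are usually delegated to a library such as NLTK (`PorterStemmer`, `WordNetLemmatizer`), so they are omitted here.

```python
import re
import string

STOP_WORDS = {'a', 'an', 'and', 'in', 'of', 'the', 'to'}  # tiny illustrative list (assumption)
NUM_RE = re.compile(r'^\d+(?:\.\d+)?$')                   # integers and decimals such as '1.4'
URL_RE = re.compile(r'^(?:https?://|www\.)\S+$')          # simple URL pattern (assumption)


def normalize_tokens(tokens, split_hyphens=True, replace_numbers=True,
                     replace_urls=True, remove_stop_words=True,
                     remove_punctuation=False):
    """Apply some of the optional preprocessing steps to a list of tokens."""
    out = []
    for token in tokens:
        if replace_urls and URL_RE.match(token):
            out.append('<URL>')
            continue
        # 'health-based' -> ['health', 'based']; a bare '-' token is left alone
        if split_hyphens and '-' in token.strip('-'):
            pieces = [piece for piece in token.split('-') if piece]
        else:
            pieces = [token]
        for piece in pieces:
            if replace_numbers and NUM_RE.match(piece):
                out.append('<NUM>')
            elif remove_stop_words and piece in STOP_WORDS:
                continue
            elif remove_punctuation and all(ch in string.punctuation for ch in piece):
                continue
            else:
                out.append(piece)
    return out


tokens = '1.4 mcg of the health-based score , see https://example.org'.split()
print(normalize_tokens(tokens))
# ['<NUM>', 'mcg', 'health', 'based', 'score', ',', 'see', '<URL>']
```

Each step is controlled by its own flag, so you can test the steps one at a time against the baseline whitespace tokenization.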

## 3. Convert text to word indices

#### Build vocabulary

In [14]:
token2idx = {'<PAD_TOKEN>': 0}
idx2token = ['<PAD_TOKEN>']

for tokenized_document in tokenized_documents:
    for token in tokenized_document:
        if token not in token2idx:
            token2idx[token] = len(token2idx)
            idx2token.append(token)
            
print('# Unique tokens in vocabulary = %d' % len(token2idx))

# Unique tokens in vocabulary = 20331


In [15]:
{k: v for (k,v) in list(token2idx.items())[:10]}

{'<PAD_TOKEN>': 0,
 'one': 1,
 'hypothesis': 2,
 'suggests': 3,
 'that': 4,
 'the': 5,
 'differential': 6,
 'response': 7,
 'to': 8,
 'ondansetron': 9}

In [16]:
idx2token[4]

'that'

In [17]:
idx2token[9]

'ondansetron'

In [18]:
idx2token[0]

'<PAD_TOKEN>'

#### Replace text tokens with integer indices

In [19]:
X = []
for tokenized_document in tokenized_documents:
    X.append([token2idx[token] for token in tokenized_document])

In [20]:
X[:5]

[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 8, 21, 22, 23,
  24, 5, 25, 26, 10, 27, 28, 29, 30, 31, 32, 5, 33, 4, 34, 35, 5, 36, 37, 15, 38, 17, 39],
 [8, 40, 41, 42, 43, 44, 45, 15, 46, 17, 11, 47, 48, 15, 49, 17, 50, 51, 52, 53, 42, 54,
  55, 8, 56, 30, 57, 42, 39],
 [58, 59, 60, 61, 62, 24, 63, 64, 39],
 [65, 66, 67, 68, 69, 70, 5, 71, 63, 72, 39],
 [5, 73, 74, 5, 75, 11, 76, 77, 11, 78, 24, 71, 79, 80, 32, 5, 81, 82, 15, 83, 17, 11, 5,
  84, 85, 15, 86, 17, 32, 8, 87, 88, 89, 39]]
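In the cells above, the vocabulary is built from all 10,000 documents before the train/validation/test split in Section 4, so every test token already has an index. A common alternative is to split first, build the vocabulary from the training documents only, and map unseen tokens to a reserved unknown-token index. The sketch below does this. `build_vocab`, `encode`, `<UNK_TOKEN>`, and `min_count` are names introduced here, not taken from the notebook.

```python
from collections import Counter

PAD_TOKEN, UNK_TOKEN = '<PAD_TOKEN>', '<UNK_TOKEN>'


def build_vocab(tokenized_docs, min_count=1):
    """Index tokens seen at least min_count times; 0 = padding, 1 = unknown."""
    counts = Counter(token for doc in tokenized_docs for token in doc)
    idx2token = [PAD_TOKEN, UNK_TOKEN]
    # Counter keeps first-seen order (Python 3.7+), like the notebook's loop
    idx2token += [token for token, count in counts.items() if count >= min_count]
    token2idx = {token: idx for idx, token in enumerate(idx2token)}
    return token2idx, idx2token


def encode(tokenized_docs, token2idx):
    """Replace tokens with indices, mapping out-of-vocabulary tokens to UNK."""
    unk_idx = token2idx[UNK_TOKEN]
    return [[token2idx.get(token, unk_idx) for token in doc] for doc in tokenized_docs]


# Hypothetical toy documents; in practice these would be the training split only
train_docs = [['all', 'patients', 'received', 'six', 'doses', '.'],
              ['no', 'patients', 'died', '.']]
test_docs = [['all', 'patients', 'recovered', '.']]

token2idx, idx2token = build_vocab(train_docs)
print(encode(test_docs, token2idx))  # [[2, 3, 1, 7]]: 'recovered' is unseen -> UNK (1)
```

Raising `min_count` also maps rare training tokens to UNK. This shrinks the embedding matrix and gives the model practice with the UNK index before it sees unfamiliar test tokens.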

## 4. Prepare data

#### Examine distribution of sequence lengths

In [21]:
sequence_lengths = [len(x) for x in X]
pd.Series(sequence_lengths).describe()

count    10000.000000
mean        25.417900
std         14.265715
min          1.000000
25%         16.000000
50%         23.000000
75%         32.000000
max        211.000000
dtype: float64
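The lengths are strongly right-skewed: 75% of sentences have at most 32 tokens, but the longest has 211. The notebook handles this by padding each batch only to its own longest sequence (Section 5). Another common option is to cap every sequence at a fixed length, such as a high percentile of the length distribution, and truncate the rare longer ones. Below is a minimal NumPy sketch; the 95th-percentile cutoff and the helper name are assumptions. Keras's `pad_sequences` defaults to `padding='pre'` and `truncating='pre'`, whereas this sketch pads and truncates at the end.

```python
import numpy as np


def pad_or_truncate(sequences, maxlen, pad_value=0):
    """Post-pad and post-truncate every sequence to exactly maxlen tokens."""
    out = np.full((len(sequences), maxlen), pad_value, dtype=np.int64)
    for i, seq in enumerate(sequences):
        seq = seq[:maxlen]           # drop tokens beyond maxlen (keep the start)
        out[i, :len(seq)] = seq
    return out


toy_sequences = [[58, 59, 60], [65, 66], [5, 73, 74, 5, 75]]  # hypothetical index lists
lengths = [len(seq) for seq in toy_sequences]
maxlen = int(np.percentile(lengths, 95))                     # cap at the 95th percentile
print(maxlen)
print(pad_or_truncate(toy_sequences, maxlen))
```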

#### Split into train, validation, and test sets

In [22]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, shuffle=True)

print('# Train documents = %d' % len(y_train))
print('# Validation documents = %d' % len(y_val))
print('# Test documents = %d' % len(y_test))

# Train documents = 6400
# Validation documents = 1600
# Test documents = 2000
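The split above is purely random, so class proportions drift slightly between sets. In Section 6 the test-set supports range from 381 to 421, even though every class has 2,000 sentences overall. Passing `stratify` to scikit-learn's `train_test_split` keeps the proportions equal in every split, and `random_state` makes the split repeatable. The sketch below runs on toy data; the seed value is arbitrary. Because sentences are split individually, sentences from the same abstract can also land in different sets. If abstract IDs were kept, scikit-learn's `GroupShuffleSplit` could split by abstract instead.

```python
from collections import Counter

from sklearn.model_selection import train_test_split

# Toy stand-ins for the notebook's X and y: 100 "documents", 20 per class
X = [[i] for i in range(100)]
y = [i % 5 for i in range(100)]

# stratify keeps each class's share identical across the splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=True, stratify=y, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.2, shuffle=True, stratify=y_train, random_state=0)

print(len(X_train), len(X_val), len(X_test))  # 64 16 20
print(Counter(y_test))                         # 4 documents per class
```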


## 5. RNN (GRU)

#### Model using Keras Sequential API

In [23]:
def build_model(vocab_size=len(token2idx), n_classes=5, embedding_dim=100, gru_dim=32, dropout=0.2, optimizer='adam'):
    model = Sequential()
    model.add(Embedding(input_dim=vocab_size,
                        output_dim=embedding_dim,
                        mask_zero=True))
    model.add(GRU(units=gru_dim))
    model.add(Dropout(rate=dropout))
    model.add(Dense(n_classes, activation='softmax'))
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['accuracy'])
    return model
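To make the architecture concrete, the NumPy sketch below traces the forward pass that the model above computes for a single sentence: an embedding lookup, a GRU over the non-padding tokens, and a softmax output layer. It is not Keras's implementation. The weights are random, there is no batching or training, and dropout is omitted because it is inactive at prediction time. The gate layout follows the textbook GRU equations rather than any particular Keras version's options, and the class name is hypothetical. Skipping index 0 mimics `mask_zero=True`, under which masked timesteps carry the previous state forward. With `padding='post'`, the final state is therefore the state after the last real token.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class TinyGRUClassifier:
    """Hypothetical NumPy sketch of Embedding(mask_zero=True) -> GRU -> Dense(softmax)."""

    def __init__(self, vocab_size, embedding_dim, gru_dim, n_classes, seed=0):
        rng = np.random.default_rng(seed)
        self.E = rng.normal(0, 0.1, (vocab_size, embedding_dim))   # embedding matrix
        # Weights for the update (z), reset (r) and candidate (h) computations
        self.W = {g: rng.normal(0, 0.1, (embedding_dim, gru_dim)) for g in 'zrh'}
        self.U = {g: rng.normal(0, 0.1, (gru_dim, gru_dim)) for g in 'zrh'}
        self.b = {g: np.zeros(gru_dim) for g in 'zrh'}
        self.W_out = rng.normal(0, 0.1, (gru_dim, n_classes))
        self.b_out = np.zeros(n_classes)
        self.gru_dim = gru_dim

    def gru_step(self, x, h):
        z = sigmoid(x @ self.W['z'] + h @ self.U['z'] + self.b['z'])
        r = sigmoid(x @ self.W['r'] + h @ self.U['r'] + self.b['r'])
        h_tilde = np.tanh(x @ self.W['h'] + (r * h) @ self.U['h'] + self.b['h'])
        return z * h + (1.0 - z) * h_tilde

    def predict_proba(self, token_ids):
        h = np.zeros(self.gru_dim)
        for token_id in token_ids:
            if token_id == 0:        # padding index: leave the hidden state unchanged
                continue
            h = self.gru_step(self.E[token_id], h)
        logits = h @ self.W_out + self.b_out
        exp = np.exp(logits - logits.max())   # numerically stable softmax
        return exp / exp.sum()


model = TinyGRUClassifier(vocab_size=100, embedding_dim=8, gru_dim=4, n_classes=5)
unpadded = model.predict_proba([58, 59, 60, 39])
post_padded = model.predict_proba([58, 59, 60, 39, 0, 0, 0])
print(np.allclose(unpadded, post_padded))  # True: padding steps are skipped
```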

#### Batch generator for batch-wise padding

In [24]:
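def batch_generator(X, y, batch_size=64, shuffle=False):
    # Loop forever: Keras's *_generator methods draw a fixed number of steps per epoch
    while True:
        if shuffle:
            # Reshuffle documents and labels together at the start of every pass
            X, y = sklearn.utils.shuffle(X, y)

        for idx in range(0, len(X), batch_size):
            X_batch = X[idx:idx + batch_size]
            y_batch = y[idx:idx + batch_size]

            # Pad only up to the longest sequence in this batch; index 0 is the padding token
            batch_sequence_lengths = [len(x) for x in X_batch]
            X_batch = pad_sequences(sequences=X_batch,
                                    maxlen=max(batch_sequence_lengths),
                                    padding='post',
                                    value=0)
            yield X_batch, np.asarray(y_batch)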
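Padding each batch only to its own longest sequence, as `batch_generator` does, can save a lot of computation compared with padding the whole dataset to the single longest sentence. The sketch below works through the arithmetic with a hypothetical helper and made-up lengths. Its last line also shows length bucketing (sorting by length before batching), an optional refinement that this notebook does not use. In practice, bucketed batches are usually shuffled as whole batches so that training order stays random.

```python
def padded_cells(lengths, batch_size):
    """Total tensor cells (tokens + padding) when each batch is padded to its own max length."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        total += len(batch) * max(batch)
    return total


lengths = [9, 47, 11, 29, 34, 211, 23, 16]                      # hypothetical sentence lengths
real_tokens = sum(lengths)                                      # 380
global_cells = len(lengths) * max(lengths)                      # pad all to 211: 1688
batchwise_cells = padded_cells(lengths, batch_size=4)           # 4*47 + 4*211 = 1032
bucketed_cells = padded_cells(sorted(lengths), batch_size=4)    # 4*23 + 4*211 = 936
print(real_tokens, global_cells, batchwise_cells, bucketed_cells)
```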

#### Train model

In [25]:
batch_size = 64

n_train_batches = int(np.ceil(len(X_train) / batch_size))
n_val_batches = int(np.ceil(len(X_val) / batch_size))
n_test_batches = int(np.ceil(len(X_test) / batch_size))

train_generator = batch_generator(X_train, y_train, shuffle=True)  # reshuffle every epoch
val_generator = batch_generator(X_val, y_val)
test_generator = batch_generator(X_test, y_test)  # no shuffling, so predictions stay aligned with y_test

print('# train batches = %d' % n_train_batches)
print('# val batches = %d' % n_val_batches)
print('# test batches = %d' % n_test_batches)

# train batches = 100
# val batches = 25
# test batches = 32


In [26]:
model = build_model()

model.fit_generator(generator=train_generator,
                    steps_per_epoch=n_train_batches,
                    validation_data=val_generator,
                    validation_steps=n_val_batches,
                    epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7fa4f0477be0>

## 6. Evaluation

#### Get class prediction probabilities for test set

In [27]:
y_pred_proba = model.predict_generator(generator=test_generator,
                                       steps=n_test_batches)

In [28]:
y_pred_proba[:5]

array([[9.6741875e-05, 2.1852116e-04, 9.8095691e-01, 3.9991997e-03,
        1.4728604e-02],
       [1.2877173e-03, 2.3125400e-04, 7.8454393e-01, 1.9392265e-03,
        2.1199793e-01],
       [8.5129309e-04, 8.2725929e-03, 8.9508820e-01, 7.7318639e-02,
        1.8469281e-02],
       [6.2861266e-03, 7.9566188e-02, 5.1076740e-01, 3.8030130e-01,
        2.3078980e-02],
       [2.6253971e-01, 7.1709919e-01, 5.9131352e-04, 2.1120450e-03,
        1.7657692e-02]], dtype=float32)

#### Convert class probabilities to integer label prediction

In [29]:
y_pred = np.argmax(y_pred_proba, axis=1)

In [30]:
y_pred[:5]

array([2, 2, 2, 2, 1])

#### Results

In [31]:
accuracy = accuracy_score(y_test, y_pred)
print('Test set accuracy = %.3f' % accuracy)
print()
print(classification_report(y_test, y_pred, target_names=UNIQUE_LABELS))

Test set accuracy = 0.577

              precision    recall  f1-score   support

  BACKGROUND       0.45      0.55      0.49       391
   OBJECTIVE       0.57      0.52      0.54       394
     METHODS       0.59      0.72      0.65       381
     RESULTS       0.80      0.63      0.70       413
 CONCLUSIONS       0.53      0.48      0.51       421

   micro avg       0.58      0.58      0.58      2000
   macro avg       0.59      0.58      0.58      2000
weighted avg       0.59      0.58      0.58      2000
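`classification_report` derives every number in the table above from the confusion matrix. The minimal NumPy sketch below reproduces that computation; the helper name is hypothetical. Precision divides each diagonal entry by its column total, recall divides it by its row total, and F1 is their harmonic mean. The macro average is the unweighted mean over classes, and the weighted average weights each class by its support. For single-label multiclass predictions, the micro average equals accuracy, which is why the micro-averaged scores above (0.58) match the test accuracy of 0.577 after rounding. Printing the confusion matrix for the real test set would also show which sections the model confuses with each other.

```python
import numpy as np


def report_from_confusion(y_true, y_pred, n_classes):
    cm = np.zeros((n_classes, n_classes), dtype=int)   # rows: true class, columns: predicted
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    tp = np.diag(cm).astype(float)
    support = cm.sum(axis=1)        # true examples per class
    predicted = cm.sum(axis=0)      # predictions per class
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, support, out=np.zeros_like(tp), where=support > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return {
        'confusion_matrix': cm,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'macro_f1': f1.mean(),
        'weighted_f1': np.average(f1, weights=support),
        'accuracy': tp.sum() / cm.sum(),
    }


# Toy labels (hypothetical), three classes
report = report_from_confusion([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0], n_classes=3)
print(report['confusion_matrix'])
print(report['precision'], report['recall'], report['f1'])
```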

