# Keyword-Extraction using BERT

Use BERT Token Classification Model to extract keyword tokens from a sentence.

## Prepare Dataset for BERT.

In [1]:
import numpy as np
import copy
import os
import re
from nltk.tokenize import sent_tokenize

In [2]:
train_path = "/Users/dunbanghe/Desktop/BERT/train"
validation_path = "/Users/dunbanghe/Desktop/BERT/validation"

In [3]:
tr_txt = sorted([f for f in os.listdir(train_path) if f.endswith(".abstr")])
tr_key = sorted([f for f in os.listdir(train_path) if f.endswith(".uncontr")])

val_txt = sorted([f for f in os.listdir(validation_path) if f.endswith(".abstr")])
val_key = sorted([f for f in os.listdir(validation_path) if f.endswith(".uncontr")])

In [4]:
tr_filekey = dict()
for i, k in enumerate(tr_txt):
    tr_filekey[tr_key[i]] = k

val_filekey = dict()
for i, k in enumerate(val_txt):
    val_filekey[val_key[i]] = k

#filekey

In [5]:
def convert(key, filekey, path):
    sentences = ""
    
    for line in open(path + "/" + filekey[key], 'r'):
        sentences += (" " + line.rstrip())
        sentences = re.sub(r'\t', '', sentences)
        
    tokens = sent_tokenize(sentences)
    #print(tokens)
    
    key_file = open(path + "/" + str(key),'r').read().rstrip()
    key_file = re.sub('\s+',' ',key_file).lower()
    #keys = [line.strip() for line in key_file]
    keys = key_file.split("; ")
    #print(keys)
    
    key_sent = []
    labels = []
    for token in tokens:
        z = ['O'] * len(token.split())
        for k in keys:
            if k in token:
                
                if len(k.split())==1:
                    try:
                        z[token.lower().split().index(k.lower().split()[0])] = 'B'
                    except ValueError:
                        continue
                elif len(k.split())>1:
                    try:
                        if token.lower().split().index(k.lower().split()[0]) and token.lower().split().index(k.lower().split()[-1]):
                            z[token.lower().split().index(k.lower().split()[0])] = 'B'
                            for j in range(1, len(k.split())):
                                z[token.lower().split().index(k.lower().split()[j])] = 'I'
                    except ValueError:
                        continue
        for m, n in enumerate(z):
            if z[m] == 'I' and z[m-1] == 'O':
                z[m] = 'O'

        if set(z) != {'O'}:
            labels.append(z) 
            key_sent.append(token)
    return key_sent, labels

In [6]:
tr_sentences_ = []
tr_labels_ = []
val_sentences_ = []
val_labels_ = []

for tr_key, tr_value in tr_filekey.items():
    s, l = convert(tr_key,tr_filekey,train_path)
    tr_sentences_.append(s)
    tr_labels_.append(l)
tr_sentences = [item for sublist in tr_sentences_ for item in sublist]
tr_labels = [item for sublist in tr_labels_ for item in sublist]

#print(tr_sentences)
print(len(tr_sentences), len(tr_labels))


for val_key, val_value in val_filekey.items():
    s, l = convert(val_key,val_filekey,validation_path)
    val_sentences_.append(s)
    val_labels_.append(l)
val_sentences = [item for sublist in val_sentences_ for item in sublist]
val_labels = [item for sublist in val_labels_ for item in sublist]

print(len(val_sentences), len(val_labels))



3134 3134
1452 1452


## Model

In [7]:
import pandas as pd
import numpy as np
from tqdm import tqdm, trange
import torch
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pad_sequence
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from pytorch_pretrained_bert import BertTokenizer, BertConfig
from pytorch_pretrained_bert import BertForTokenClassification, BertAdam

Using TensorFlow backend.


In [8]:
MAX_LEN = 75
bs = 32

In [9]:
tag2idx = {'B': 0, 'I': 1, 'O': 2}
tags_vals = ['B', 'I', 'O']

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

In [11]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [43]:
tokenized_texts_tr = [tokenizer.tokenize(sent) for sent in tr_sentences]
tokenized_texts_val = [tokenizer.tokenize(sent) for sent in val_sentences]
#len(tokenized_texts_val)

1452

In [13]:
tr_input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts_tr],
                          maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")

tr_tags = pad_sequences([[tag2idx.get(l) for l in lab] for lab in tr_labels],
                     maxlen=MAX_LEN, value=tag2idx["O"], padding="post",
                     dtype="long", truncating="post")

val_input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts_val],
                          maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")

val_tags = pad_sequences([[tag2idx.get(l) for l in lab] for lab in val_labels],
                     maxlen=MAX_LEN, value=tag2idx["O"], padding="post",
                     dtype="long", truncating="post")


In [14]:
tr_attention_masks = [[float(i>0) for i in ii] for ii in tr_input_ids]
val_attention_masks = [[float(i>0) for i in ii] for ii in val_input_ids]

In [45]:
#tr_inputs, val_inputs, tr_tags, val_tags = train_test_split(tr_input_ids, tr_tags, random_state=2018, test_size=0.1)
#tr_masks, val_masks, _, _ = train_test_split(tr_attention_masks, tr_input_ids, random_state=2018, test_size=0.1)

tr_inputs = tr_input_ids
tr_tags = tr_tags
tr_masks = tr_attention_masks

val_inputs = val_input_ids
#print(len(val_inputs))
val_tags = val_tags
val_masks = val_attention_masks

1452


In [16]:
tr_inputs = torch.tensor(tr_inputs)
tr_tags = torch.tensor(tr_tags)
tr_masks = torch.tensor(tr_masks)

val_inputs = torch.tensor(val_inputs)
val_tags = torch.tensor(val_tags)
val_masks = torch.tensor(val_masks)

In [17]:
train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=bs)

valid_data = TensorDataset(val_inputs, val_masks, val_tags)
valid_sampler = SequentialSampler(valid_data)
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=bs)

In [18]:
model = BertForTokenClassification.from_pretrained("bert-base-uncased", num_labels=len(tag2idx))

In [19]:
model = model.cpu()

In [20]:
FULL_FINETUNING = True
if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters()) 
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]
optimizer = Adam(optimizer_grouped_parameters, lr=3e-5)

In [21]:
from seqeval.metrics import f1_score

def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=2).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [22]:
epochs = 4
max_grad_norm = 1.0

for _ in trange(epochs, desc="Epoch"):
    # TRAIN loop
    model.train()
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    for step, batch in enumerate(train_dataloader):
        # add batch to gpu
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        # forward pass
        loss = model(b_input_ids, token_type_ids=None,
                     attention_mask=b_input_mask, labels=b_labels)
        # backward pass
        loss.backward()
        # track train loss
        tr_loss += loss.item()
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
        # update parameters
        optimizer.step()
        model.zero_grad()
    # print train loss per epoch
    print("Train loss: {}".format(tr_loss/nb_tr_steps))
    torch.save(model.state_dict(), 'checkpoint.pth')
    
    # VALIDATION on validation set
    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    predictions , true_labels = [], []
    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        with torch.no_grad():
            tmp_eval_loss = model(b_input_ids, token_type_ids=None,
                                  attention_mask=b_input_mask, labels=b_labels)
            logits = model(b_input_ids, token_type_ids=None,
                           attention_mask=b_input_mask)
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
        true_labels.append(label_ids)
        
        tmp_eval_accuracy = flat_accuracy(logits, label_ids)
        
        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy
        
        nb_eval_examples += b_input_ids.size(0)
        nb_eval_steps += 1
    
    eval_loss = eval_loss/nb_eval_steps
    print("Validation loss: {}".format(eval_loss))
    #print("Validation Accuracy: {}".format(eval_accuracy/nb_eval_steps))
    pred_tags = [tags_vals[p_i] for p in predictions for p_i in p]
    valid_tags = [tags_vals[l_ii] for l in true_labels for l_i in l for l_ii in l_i]
    #print("F1-Score: {}".format(f1_score(pred_tags, valid_tags)))

Epoch:   0%|          | 0/4 [00:00<?, ?it/s]

Train loss: 0.36790759040384874


Epoch:  25%|██▌       | 1/4 [22:47<1:08:23, 1367.81s/it]

Validation loss: 0.3278063898501189
Train loss: 0.29789082067353384


Epoch:  50%|█████     | 2/4 [48:05<47:05, 1412.72s/it]  

Validation loss: 0.30948931270319485
Train loss: 0.25308880498822856


Epoch:  75%|███████▌  | 3/4 [1:12:39<23:51, 1431.24s/it]

Validation loss: 0.28731507872757706
Train loss: 0.2121264496628119


Epoch: 100%|██████████| 4/4 [1:41:25<00:00, 1521.35s/it]

Validation loss: 0.2890006410686866





In [23]:
with open('pred_2.txt', 'w') as f:
    print(pred_tags, file=f)
with open('true_2.txt', 'w') as f:
    print(valid_tags, file=f)

In [24]:
print("Validation Accuracy: {}".format(eval_accuracy/nb_eval_steps))

Validation Accuracy: 0.9390126811594202


.

In [31]:
# Evaluation
def recall(kp_true, kp_predicted):
    tp = 0
    fn = 0 
    for i in kp_true:
        if (i in kp_predicted):
            tp = tp + 1
        else:
            fn = fn + 1
    return float(tp)/(float(tp)+float(fn)) if (float(tp)+float(fn))!=0 else 0


def precision(kp_true, kp_predicted):
    tp = 0
    fp = 0
    for i in kp_predicted:
        if(i in kp_true):
            tp = tp + 1
        else:
            fp  = fp + 1
    
    return float(tp)/(float(tp)+float(fp))  if (float(tp)+float(fp)) > 0 else 0

def f1(kp_true, kp_predicted):
    precision_ = precision(kp_true, kp_predicted)
    recall_ = recall(kp_true, kp_predicted)
    return (2 * (precision_ * recall_)) / (precision_ + recall_) if precision_ + recall_ > 0 else 0


def retrive_phrase(tags_predicted, document_eng):
    #from tag 0,1,2 to keyphrase arrays (2d-sentences, words)
    kp = []
    sentence= []
    for i in range(len(tags_predicted)):
        if (tags_predicted[i] == 'O'):
            if len(sentence) != 0:
                kp.append(copy.deepcopy(sentence))
                
            sentence.clear()#we know a sentence ends when we encounter 0, so push
        elif((tags_predicted[i] == 'B')):
            if len(sentence) != 0:
                kp.append(copy.deepcopy(sentence))
            sentence.clear()#we know a sentence ends when we encounter 1, so push
 
            sentence.append(document_eng[i])
            if(i== len(tags_predicted)-1):#if it is the last element, we push
                kp.append(copy.deepcopy(sentence))
        else:
            sentence.append(document_eng[i])
            if(i== len(tags_predicted)-1):
                kp.append(copy.deepcopy(sentence))
    return kp

In [61]:
#len(pred_tags)

In [62]:
#len(tokenized_texts_val)

In [63]:
#import operator
#from functools import reduce
#len(reduce(operator.add, tokenized_texts_val))
#len(reduce(operator.add, predictions))

In [64]:
#retrive_phrase(pred,true)

In [65]:
#len(predictions)