# Keyword-Extraction using BERT

Use BERT Token Classification Model to extract keyword tokens from a sentence.

## Prepare Dataset for BERT.

Convert Sem-Eval 2010 keyword recognition dataset to BIO format dataset.

In [1]:
import os
from tqdm import tqdm
from nltk.tokenize import sent_tokenize
from pytorch_pretrained_bert import BertTokenizer

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
train_path = "maui-semeval2010-train"

In [3]:
txt = sorted([f for f in os.listdir(train_path) if not f.endswith("-justTitle.txt") and not f.endswith(".key") and not f.endswith("-CrowdCountskey")])
key = sorted([f for f in os.listdir(train_path) if f.endswith(".key")])

In [4]:
filekey = dict()
for i, k in enumerate(txt):
    filekey[key[i]] = k

In [5]:
def convert(key):
    sentences = ""
    for line in open(train_path + "/" + filekey[key], 'r'):
        sentences += (" " + line.rstrip())
    tokens = sent_tokenize(sentences)
    key_file = open(train_path + "/" + str(key),'r')
    keys = [line.strip() for line in key_file]
    key_sent = []
    labels = []
    for token in tokens:
        z = ['O'] * len(token.split())
        for k in keys:
            if k in token:
                
                if len(k.split())==1:
                    try:
                        z[token.lower().split().index(k.lower().split()[0])] = 'B'
                    except ValueError:
                        continue
                elif len(k.split())>1:
                    try:
                        if token.lower().split().index(k.lower().split()[0]) and token.lower().split().index(k.lower().split()[-1]):
                            z[token.lower().split().index(k.lower().split()[0])] = 'B'
                            for j in range(1, len(k.split())):
                                z[token.lower().split().index(k.lower().split()[j])] = 'I'
                    except ValueError:
                        continue
        for m, n in enumerate(z):
            if z[m] == 'I' and z[m-1] == 'O':
                z[m] = 'O'

        if set(z) != {'O'}:
            labels.append(z) 
            key_sent.append(token)
    return key_sent, labels

In [20]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
ave_len = 0
max_len = 0

def index_multi(l, x):
    for i, _x in enumerate(l):
        if _x == x:
            yield i

def my_convert(key):
    doc = ""
    with open(train_path + "/" + filekey[key], 'r') as f:
        doc = ' '.join(map(lambda line: line.rstrip().lower(), f.readlines()))
    sents = sent_tokenize(doc)
    # sents = map(lambda sent: bert_tokenizer.tokenize(sent), sents)
    with open(train_path + "/" + str(key),'r') as f:
        keys = [line.strip().lower() for line in f]
    
    key_sent = []
    labels = []
    for sent in sents:
        tsent = bert_tokenizer.tokenize(sent)
        label = ['O'] * len(tsent)
        for key in keys:
            if key in sent:
                tkey = bert_tokenizer.tokenize(key)
                key_length = len(tkey)
                for b_id in index_multi(tsent, tkey[0]):
                    if tsent[b_id + key_length - 1:] and tsent[b_id + key_length - 1].startswith(tkey[-1]):
                        label[b_id] = 'B' ## Begin
                        for i in range(b_id+1, b_id+len(k.split())):
                            label[i] = 'I'  ## Inner

        if set(label) != {'O'}:
            labels.append(label) 
            key_sent.append(sent)
            global max_len, ave_len
            if len(tsent) > max_len:
                max_len = len(tsent)
            ave_len += len(tsent)
    return key_sent, labels

In [21]:
sentences_ = []
labels_ = []
for key, value in tqdm(filekey.items()):
    # s, l = convert(key)  ## -> 8584 pairs
    s, l = my_convert(key)  ## -> 15486 pairs
    sentences_.append(s)
    labels_.append(l)
sentences = [item for sublist in sentences_ for item in sublist]
labels = [item for sublist in labels_ for item in sublist]
print(len(sentences), len(labels), max_len, ave_len/len(sentences), )

100%|██████████| 144/144 [00:24<00:00,  5.96it/s]

15486 15486 1022 38.97836755779414





In [17]:
max_len

def print_len():
    print(max_len)
    
print_len()

0


## Model Train

In [10]:
import pandas as pd
import numpy as np
from tqdm import tqdm, trange
import torch
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pad_sequence
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from pytorch_pretrained_bert import BertTokenizer, BertConfig
from pytorch_pretrained_bert import BertForTokenClassification, BertAdam

Using TensorFlow backend.


In [12]:
MAX_LEN = 75
bs = 32

tag2idx = {'B': 0, 'I': 1, 'O': 2}
tags_vals = ['B', 'I', 'O']

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
print(device, n_gpu)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
tokenized_texts = [tokenizer.tokenize(sent) for sent in tqdm(sentences)]

cuda 4


100%|██████████| 15486/15486 [00:09<00:00, 1702.71it/s]


In [47]:
input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],
                          maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
tags = pad_sequences([[tag2idx.get(l) for l in lab] for lab in labels],
                     maxlen=MAX_LEN, value=tag2idx["O"], padding="post",
                     dtype="long", truncating="post")
attention_masks = [[float(i>0) for i in ii] for ii in input_ids]

In [48]:
tr_inputs, val_inputs, tr_tags, val_tags, tr_masks, val_masks = \
    train_test_split(input_ids, tags, attention_masks, random_state=2018, test_size=0.1)

In [49]:
tr_inputs = torch.tensor(tr_inputs)
tr_tags = torch.tensor(tr_tags)
tr_masks = torch.tensor(tr_masks)
train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=bs)

val_inputs = torch.tensor(val_inputs)
val_tags = torch.tensor(val_tags)
val_masks = torch.tensor(val_masks)
valid_data = TensorDataset(val_inputs, val_masks, val_tags)
valid_sampler = SequentialSampler(valid_data)
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=bs)

In [50]:
model = BertForTokenClassification.from_pretrained("bert-base-uncased", num_labels=len(tag2idx))
model = model.cuda()

In [51]:
FULL_FINETUNING = True
if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters()) 
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]
optimizer = Adam(optimizer_grouped_parameters, lr=3e-5)

In [52]:
from seqeval.metrics import f1_score

def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=2).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [53]:
epochs = 4
max_grad_norm = 1.0

for _ in trange(epochs, desc="Epoch"):
    # TRAIN loop
    model.train()
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    for step, batch in enumerate(train_dataloader):
        # add batch to gpu
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        # forward pass
        loss = model(b_input_ids, token_type_ids=None,
                     attention_mask=b_input_mask, labels=b_labels)
        # backward pass
        loss.backward()
        # track train loss
        tr_loss += loss.item()
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
        # update parameters
        optimizer.step()
        model.zero_grad()
    # print train loss per epoch
    print("Train loss: {}".format(tr_loss/nb_tr_steps))
    # VALIDATION on validation set
    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    predictions , true_labels = [], []
    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        with torch.no_grad():
            tmp_eval_loss = model(b_input_ids, token_type_ids=None,
                                  attention_mask=b_input_mask, labels=b_labels)
            logits = model(b_input_ids, token_type_ids=None,
                           attention_mask=b_input_mask)
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
        true_labels.append(label_ids)
        
        tmp_eval_accuracy = flat_accuracy(logits, label_ids)
        
        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy
        
        nb_eval_examples += b_input_ids.size(0)
        nb_eval_steps += 1
    eval_loss = eval_loss/nb_eval_steps
    print("Validation loss: {}".format(eval_loss))
    print("Validation Accuracy: {}".format(eval_accuracy/nb_eval_steps))
    pred_tags = [tags_vals[p_i] for p in predictions for p_i in p]
    valid_tags = [tags_vals[l_ii] for l in true_labels for l_i in l for l_ii in l_i]
    print("F1-Score: {}".format(f1_score(pred_tags, valid_tags)))

Epoch:   0%|          | 0/4 [00:00<?, ?it/s]

Train loss: 0.044464042709303014
Validation loss: 0.020887882524759184
Validation Accuracy: 0.9916136839351124


Epoch:  25%|██▌       | 1/4 [02:11<06:35, 131.68s/it]

F1-Score: 0.7871783689489752
Train loss: 0.01776634122526974


Epoch:  50%|█████     | 2/4 [04:26<04:25, 132.64s/it]

Validation loss: 0.015417849283893498
Validation Accuracy: 0.9940528519099947
F1-Score: 0.8479544957339751
Train loss: 0.010883820699228811
Validation loss: 0.014252790297400586
Validation Accuracy: 0.9946866823652536


Epoch:  75%|███████▌  | 3/4 [06:41<02:13, 133.22s/it]

F1-Score: 0.8700398908251102
Train loss: 0.006929780181695989


Epoch: 100%|██████████| 4/4 [08:56<00:00, 134.07s/it]

Validation loss: 0.015607339843195312
Validation Accuracy: 0.9950045787545788
F1-Score: 0.8748398120461341





In [54]:
torch.save(model, 'model.pt')

## Model Inference

In [55]:
import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer

In [57]:
model = torch.load('model.pt')
model.eval()
predictions = []
true_labels = []
eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
for batch in valid_dataloader:
    batch = tuple(t.to(device) for t in batch)
    b_input_ids, b_input_mask, b_labels = batch

    with torch.no_grad():
        tmp_eval_loss = model(b_input_ids, token_type_ids=None,
                              attention_mask=b_input_mask, labels=b_labels)
        logits = model(b_input_ids, token_type_ids=None,
                       attention_mask=b_input_mask)
        
    logits = logits.detach().cpu().numpy()
    predictions.extend([list(p) for p in np.argmax(logits, axis=2)])

    label_ids = b_labels.to('cpu').numpy()
    true_labels.append(label_ids)
    tmp_eval_accuracy = flat_accuracy(logits, label_ids)

    eval_loss += tmp_eval_loss.mean().item()
    eval_accuracy += tmp_eval_accuracy

    nb_eval_examples += b_input_ids.size(0)
    nb_eval_steps += 1

pred_tags = [[tags_vals[p_i] for p_i in p] for p in predictions]
valid_tags = [[tags_vals[l_ii] for l_ii in l_i] for l in true_labels for l_i in l ]
print("Validation loss: {}".format(eval_loss/nb_eval_steps))
print("Validation Accuracy: {}".format(eval_accuracy/nb_eval_steps))
print("Validation F1-Score: {}".format(f1_score(pred_tags, valid_tags)))

Validation loss: 0.015607339843195312
Validation Accuracy: 0.9950045787545788
Validation F1-Score: 0.8748398120461341


### Get keywords from sentence

In [58]:
def keywordextract(sentence):
    text = sentence
    tkns = tokenizer.tokenize(text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tkns)
    segments_ids = [0] * len(tkns)
    tokens_tensor = torch.tensor([indexed_tokens]).to(device)
    segments_tensors = torch.tensor([segments_ids]).to(device)
    model.eval()
    prediction = []
    logit = model(tokens_tensor, token_type_ids=None,
                                  attention_mask=segments_tensors)
    logit = logit.detach().cpu().numpy()
    print(logit)
    prediction.extend([list(p) for p in np.argmax(logit, axis=2)])
    print(prediction)
    for k, j in enumerate(prediction[0]):
        if j==1 or j==0:
            print(tokenizer.convert_ids_to_tokens(tokens_tensor[0].to('cpu').numpy())[k], j)

In [59]:
text = "We present the Insertion Transformer"

In [60]:
keywordextract(text)

[[[-3.7819104 -4.384814   7.748294 ]
  [-3.840575  -4.3448853  7.7474585]
  [-4.5531235 -3.8311753  7.5557384]
  [ 4.856516  -3.870973  -1.3971982]
  [ 0.3088797 -5.191846   4.0787377]
  [-4.074701  -4.058109   7.4975424]]]
[[2, 2, 2, 0, 2, 2]]
insertion 0
