https://github.com/kuldeep7688/BioMedicalBertNer

# Utility Functions

In [1]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import TensorDataset
from seqeval.metrics import precision_score, recall_score, f1_score
from tqdm import tqdm

In [2]:
def _is_whitespace(c):
    """Tell whether the single character ``c`` counts as whitespace.

    Space, tab, carriage return and newline qualify, as does the
    narrow no-break space (U+202F).
    """
    return c in (" ", "\t", "\r", "\n") or ord(c) == 0x202F

class InputExample(object):
    """
    A single training/test example for token classification.

    Attributes:
        guid: unique id for the example.
        words: the words of the sequence.
        labels: one label per word.
        sentence: the raw sentence, when the example was built from text.
    """
    def __init__(self, guid, words=None, labels=None, sentence=None):
        """Constructs an InputExample object.
        Args:
            guid (TYPE): unique id for the example
            words (TYPE): the words of the sequence. When None and
                ``sentence`` is given, the words are obtained by splitting
                ``sentence`` on whitespace (as decided by ``_is_whitespace``).
            labels (TYPE): the label of each word of the sentence. When the
                words are derived from ``sentence`` and no labels are given,
                every word is labelled "O".
            sentence (TYPE): raw text; only consulted when ``words`` is None.
        """
        self.guid = guid
        self.words = words
        self.labels = labels
        self.sentence = sentence

        # `is not None` rather than truthiness: an empty sentence must yield
        # an empty word list, not leave words/labels as None.
        if self.words is None and self.sentence is not None:
            doc_tokens = []
            prev_is_whitespace = True
            # Split the sentence on whitespace: a non-whitespace character
            # either opens a new word or extends the word being built.
            for c in self.sentence:
                if _is_whitespace(c):
                    prev_is_whitespace = True
                    continue
                if prev_is_whitespace:
                    doc_tokens.append(c)
                else:
                    doc_tokens[-1] += c
                prev_is_whitespace = False

            self.words = doc_tokens
            if self.labels is None:
                self.labels = ["O"] * len(self.words)


class InputFeatures(object):
    """
    Container for the model-ready features of one example.

    It only stores what it is given: the input ids, the input mask, the
    segment ids and, optionally, the label ids plus the two index lists
    that map between sub-tokens and the original words.
    """
    def __init__(self, input_ids, input_mask, segment_ids, label_ids=None, token_to_orig_index=None, orig_to_token_index=None):
        # inputs fed to the model
        self.input_ids, self.input_mask, self.segment_ids = (
            input_ids, input_mask, segment_ids
        )
        # optional supervision
        self.label_ids = label_ids
        # optional sub-token <-> word bookkeeping
        self.token_to_orig_index, self.orig_to_token_index = (
            token_to_orig_index, orig_to_token_index
        )

In [22]:
def read_examples_from_file(data_dir, mode, line_splitter="\t"):
    """Read ``<data_dir>/<mode>.tsv`` into a list of InputExample.

    Every non-blank line holds one word, optionally followed by
    ``line_splitter`` and further columns; the last column is the label.
    Blank (or whitespace-only) lines and lines starting with
    ``-DOCSTART-`` close the sentence being built.

    Args:
        data_dir: directory that contains the file.
        mode: file stem (the file read is ``<mode>.tsv``); also the prefix
            of every generated guid.
        line_splitter: column separator used inside a line.

    Returns:
        list of InputExample with guids ``"<mode>-<n>"``, n counted from 1.
        Words without a label column are labelled "O".
    """
    file_path = os.path.join(data_dir, "{}.tsv".format(mode))
    examples = []

    def _emit(sentence_words, sentence_labels):
        # guids are numbered in the order sentences are emitted, from 1
        examples.append(
            InputExample(
                guid="{}-{}".format(mode, len(examples) + 1),
                words=sentence_words,
                labels=sentence_labels
            )
        )

    words = []
    labels = []
    with open(file_path, encoding="utf-8") as f:
        for raw_line in f:
            # Drop the line terminator first so it can never leak into a
            # word (unlabelled lines) or a label.
            line = raw_line.rstrip("\r\n")
            if line.startswith("-DOCSTART-") or not line.strip():
                if words:
                    _emit(words, labels)
                    words, labels = [], []
                continue

            splits = line.split(line_splitter)
            words.append(splits[0])
            # examples may carry no label column (e.g. unlabelled test data)
            labels.append(splits[-1].strip() if len(splits) > 1 else "O")

    # the file may not end with a blank line
    if words:
        _emit(words, labels)
    return examples


def get_i_label(beginning_label, label_map):
    """Return the continuation ("I-") label for the extra sub-tokens of a word.

    Args:
        beginning_label: label of the word, e.g. "B-Gene", "I-Gene" or "O".
        label_map: not used; kept so existing callers keep working.

    Returns:
        "I-<type>" when ``beginning_label`` starts with "B-" or "I-",
        otherwise "O".
    """
    # Only the two-character prefix decides the scheme. Searching for "B-" /
    # "I-" anywhere in the label would mangle entity types that themselves
    # contain those substrings (e.g. "B-HLA-B-27").
    if beginning_label.startswith(("B-", "I-")):
        return "I-" + beginning_label[2:]
    return "O"

In [4]:
def convert_examples_to_features(
    examples, label_map, max_seq_length,
    tokenizer, label_end_token="<EOS>",
    pad_token_label_id=-1, mask_padding_with_zero=True,
    logger=None, summary_writer=None, mode=None
):
    """
    Prepare features to be given as input to Bert

    Each word is split by ``tokenizer``; the first sub-token keeps the word's
    label id and the remaining sub-tokens get the id of the matching "I-"
    label. The sequence is truncated to leave room for the cls/sep tokens,
    wrapped in them, and padded up to ``max_seq_length``.

    Args:
        examples: iterable of objects exposing ``guid``, ``words``, ``labels``.
        label_map: dict mapping label string -> label id.
        max_seq_length: length of every produced sequence (special tokens
            included).
        tokenizer: object providing ``tokenize``, ``convert_tokens_to_ids``,
            ``cls_token``, ``sep_token`` and ``pad_token_id``.
        label_end_token: unused.
        pad_token_label_id: label id put on cls, sep and padding positions.
        mask_padding_with_zero: unused; padding is always masked with 0.
        logger: optional logger; the first 5 examples are logged.
        summary_writer: optional writer whose ``add_text`` receives the first
            5 examples under the tag ``mode``.
        mode: tag passed to ``summary_writer.add_text``.

    Returns:
        list of InputFeatures, one per example.
    """
    special_tokens_count = 2  # for bert cls sentence sep
    max_word_tokens = max_seq_length - special_tokens_count

    features = []
    for ex_index, example in enumerate(examples):
        tokens = []
        label_ids = []
        token_to_orig_index = []
        orig_to_token_index = []
        for word_idx, (word, label) in enumerate(zip(example.words, example.labels)):
            # position of the word's first sub-token (before cls is prepended)
            orig_to_token_index.append(len(tokens))
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                continue
            tokens.extend(word_tokens)
            # Use the real label id for the first token of the word ...
            label_ids.append(label_map[label])
            # ... and propagate the I-tag to the remaining pieces. It is
            # looked up only when the word was actually split, so a label
            # set without the I- variant does not fail on unsplit words.
            n_continuations = len(word_tokens) - 1
            if n_continuations:
                label_ids.extend(
                    [label_map[get_i_label(label, label_map)]] * n_continuations
                )
            token_to_orig_index.extend([word_idx] * len(word_tokens))

        if len(tokens) > max_word_tokens:
            tokens = tokens[:max_word_tokens]
            label_ids = label_ids[:max_word_tokens]
            # keep the token -> word map in step with the truncated tokens
            token_to_orig_index = token_to_orig_index[:max_word_tokens]

        tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token]
        label_ids = [pad_token_label_id] + label_ids + [pad_token_label_id]
        segment_ids = [0] * len(tokens)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)

        # Zero pad up to the sequence length
        padding_length = max_seq_length - len(input_ids)
        input_ids += [tokenizer.pad_token_id] * padding_length
        input_mask += [0] * padding_length
        segment_ids += [0] * padding_length
        label_ids += [pad_token_label_id] * padding_length

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length

        if ex_index < 5 and (logger or summary_writer):
            rendered = [
                "guid: {}".format(example.guid),
                "tokens: {}".format(" ".join([str(x) for x in tokens])),
                "input ids : {}".format(" ".join([str(x) for x in input_ids])),
                "input_mask : {}".format(" ".join([str(x) for x in input_mask])),
                "segment_ids : {}".format(" ".join([str(x) for x in segment_ids])),
                "label_ids : {}".format(" ".join([str(x) for x in label_ids])),
            ]
            if logger:
                logger.info("****** EXAMPLES ********")
                for message in rendered:
                    logger.info(message)
            if summary_writer:
                for message in rendered:
                    summary_writer.add_text(mode, message, 0)

        features.append(
            InputFeatures(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                label_ids=label_ids,
                token_to_orig_index=token_to_orig_index,
                orig_to_token_index=orig_to_token_index
            )
        )

    return features

In [5]:
def get_labels(path):
    """Load the label vocabulary from a text file with one label per line.

    Args:
        path: path of the labels file; a falsy value means "no file".

    Returns:
        The labels in file order, with surrounding whitespace removed and
        blank lines skipped. "O" is prepended when the file does not list
        it. Returns None when ``path`` is falsy.
    """
    if not path:
        return None

    # utf-8 to match how the dataset files are read
    with open(path, "r", encoding="utf-8") as f:
        stripped = (line.strip() for line in f.read().splitlines())
        # a blank line would otherwise become a bogus "" label
        labels = [label for label in stripped if label]

    if "O" not in labels:
        labels = ["O"] + labels
    return labels


def load_and_cache_examples(
    max_seq_length, tokenizer, label_map, pad_token_label_id,
    mode, data_dir=None, logger=None, summary_writer=None,
    sentence_list=None, return_features_and_examples=False
):
    """Build a TensorDataset of features from a dataset file or raw sentences.

    Despite the name, nothing is cached: features are recomputed on every
    call.

    Args:
        max_seq_length: sequence length of the produced features.
        tokenizer: tokenizer handed to ``convert_examples_to_features``.
        label_map: dict mapping label string -> label id.
        pad_token_label_id: label id used on special/padding positions.
        mode: dataset split; selects ``<mode>.tsv`` inside ``data_dir``.
        data_dir: directory of the dataset file (used when ``sentence_list``
            is None).
        logger: optional logger forwarded to the feature conversion.
        summary_writer: optional writer forwarded to the feature conversion.
        sentence_list: raw sentences to featurize instead of reading a file;
            takes precedence over ``data_dir``.
        return_features_and_examples: also return the examples and features.

    Returns:
        ``TensorDataset(input_ids, input_mask, segment_ids, label_ids)``, or
        the tuple ``(dataset, examples, features)`` when
        ``return_features_and_examples`` is true.

    Raises:
        ValueError: when neither ``sentence_list`` nor ``data_dir`` is given.
    """
    if sentence_list is not None:
        # raw text input; each sentence is split into words by InputExample
        examples = [
            InputExample(guid=idx, words=None, labels=None, sentence=sentence)
            for idx, sentence in enumerate(sentence_list)
        ]
    elif data_dir:
        print("Creating features from dataset file at {}".format(data_dir))
        examples = read_examples_from_file(data_dir, mode)
    else:
        # previously this fell through and failed later on an unbound name
        raise ValueError(
            "Either `sentence_list` or `data_dir` must be provided."
        )

    features = convert_examples_to_features(
        examples=examples, label_map=label_map,
        max_seq_length=max_seq_length, mode=mode,
        tokenizer=tokenizer,
        pad_token_label_id=pad_token_label_id,
        logger=logger, summary_writer=summary_writer
    )

    # Convert into tensors and build dataset
    def _stack(attribute):
        return torch.tensor(
            [getattr(feature, attribute) for feature in features],
            dtype=torch.long
        )

    dataset = TensorDataset(
        _stack("input_ids"), _stack("input_mask"),
        _stack("segment_ids"), _stack("label_ids")
    )
    if return_features_and_examples:
        return dataset, examples, features
    return dataset

In [6]:
def count_parameters(model):
    """Print the number of parameters of ``model`` that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    message = "Number of trainable parameters in the model are {}".format(trainable)
    print(message)


def get_result_matrix(
    loss, label_map, predictions_tensor, sentence_input_ids,
    labels_tensor, sep_token_id, give_lists=False
):
    """
    Turn label/prediction id matrices into label strings and score them.

    Every row is read left to right until the position whose input id equals
    ``sep_token_id``; ids seen before that are mapped back to label strings
    through the inverse of ``label_map``.

    The returned dict always has the keys "loss", "precision", "recall",
    "f1", "out_label_list" and "preds_list"; ``give_lists`` currently makes
    no difference to its content.
    """
    id_to_label = {idx: name for name, idx in label_map.items()}

    out_label_list = [[] for _ in range(labels_tensor.shape[0])]
    preds_list = [[] for _ in range(predictions_tensor.shape[0])]

    n_rows, n_cols = labels_tensor.shape[0], labels_tensor.shape[1]
    for row in range(n_rows):
        gold_row = out_label_list[row]
        pred_row = preds_list[row]
        for col in range(n_cols):
            # nothing at or after the separator is scored
            if sentence_input_ids[row, col] == sep_token_id:
                break
            gold_row.append(id_to_label[labels_tensor[row][col]])
            pred_row.append(id_to_label[predictions_tensor[row][col]])

    precision = precision_score(out_label_list, preds_list)
    recall = recall_score(out_label_list, preds_list)
    f1 = f1_score(out_label_list, preds_list)
    return {
        "loss": loss,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "out_label_list": out_label_list,
        "preds_list": preds_list
    }

In [7]:
def train_epoch(
    model, dataset, batch_size, label_map, max_grad_norm,
    optimizer, scheduler, device, sep_token_id, summary_writer=None
):
    """Train ``model`` for one pass over ``dataset`` and score the epoch.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...,
            token_type_ids=..., labels=...)`` and returning
            ``(predictions, labels, loss)``.
        dataset: dataset yielding (input_ids, input_mask, segment_ids,
            label_ids) batches.
        batch_size: batch size of the data loader.
        label_map: dict mapping label string -> label id.
        max_grad_norm: gradient-norm clipping threshold.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once per batch.
        device: device the batches are moved to.
        sep_token_id: id of the separator token (ends a sentence when scoring).
        summary_writer: optional writer exposing ``add_scalar``.

    Returns:
        The dict produced by ``get_result_matrix`` for the whole epoch, with
        the mean batch loss as "loss".
    """
    tr_loss = 0.0

    preds = []
    out_label_ids = []
    input_ids_list = []

    model.train()
    dataloader = DataLoader(
        dataset, sampler=RandomSampler(dataset), batch_size=batch_size
    )
    # Refresh the progress text about 20 times per epoch; clamp to 1 so that
    # short epochs do not end up taking a modulo by zero.
    print_stats_at_step = max(1, len(dataloader) // 20)
    epoch_iterator = tqdm(dataloader)
    for step, batch in enumerate(epoch_iterator):
        model.zero_grad()
        batch = tuple(t.to(device) for t in batch)
        inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "token_type_ids": batch[2],
            "labels": batch[3]
        }
        # getting outputs; the model hands back the labels it scored against
        logits, inputs["labels"], loss = model(**inputs)

        # propagating loss backwards and scheduler and opt. steps
        loss.backward()
        step_loss = loss.item()
        tr_loss += step_loss

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

        if summary_writer:
            if hasattr(scheduler, "get_last_lr"):
                current_lr = scheduler.get_last_lr()[0]
            else:
                current_lr = scheduler.get_lr()[0]
            # the step index is local to this epoch
            summary_writer.add_scalar("Loss/train", step_loss, step)
            summary_writer.add_scalar("LR/train", current_lr, step)

        # collect predictions, labels and input ids for the epoch metrics;
        # the first input-id column is dropped before scoring
        preds.append(logits.detach().cpu().numpy())
        out_label_ids.append(inputs["labels"].detach().cpu().numpy())
        input_ids_list.append(inputs["input_ids"][:, 1:].detach().cpu().numpy())

        if step % print_stats_at_step == 0:
            epoch_iterator.set_description(
                f'Tr Iter: {step+1}| step_loss: {step_loss: .3f}'
            )

    preds = np.vstack(preds)
    input_ids_list = np.vstack(input_ids_list)
    out_label_ids = np.vstack(out_label_ids)
    epoch_loss = tr_loss / len(dataloader)
    results = get_result_matrix(
        epoch_loss, label_map, preds, input_ids_list,
        out_label_ids, sep_token_id, give_lists=False
    )

    if summary_writer:
        summary_writer.add_scalar("F1_epoch/train", results["f1"])
        summary_writer.add_scalar("Precision_epoch/train", results["precision"])
        summary_writer.add_scalar("Recall_epoch/train", results["recall"])

    return results


def eval_epoch(
    model, dataset, batch_size, label_map, device, sep_token_id,
    summary_writer=None, give_lists=False
):
    """Evaluate ``model`` on ``dataset`` without updating it.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...,
            token_type_ids=..., labels=...)`` and returning
            ``(predictions, labels, loss)``.
        dataset: dataset yielding (input_ids, input_mask, segment_ids,
            label_ids) batches.
        batch_size: batch size of the data loader.
        label_map: dict mapping label string -> label id.
        device: device the batches are moved to.
        sep_token_id: id of the separator token (ends a sentence when scoring).
        summary_writer: optional writer exposing ``add_scalar``.
        give_lists: forwarded to ``get_result_matrix``.

    Returns:
        The dict produced by ``get_result_matrix``, with the mean batch loss
        as "loss".
    """
    eval_loss = 0.0
    preds = []
    out_label_ids = []
    input_ids_list = []

    model.eval()
    dataloader = DataLoader(
        dataset, sampler=SequentialSampler(dataset), batch_size=batch_size
    )
    # Refresh the progress text about 10 times; clamp to 1 so that short
    # evaluations do not end up taking a modulo by zero.
    print_stats_at_step = max(1, len(dataloader) // 10)
    epoch_iterator = tqdm(dataloader)
    with torch.no_grad():
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "labels": batch[3]
            }
            # getting outputs; the model hands back the labels it scored against
            logits, inputs["labels"], loss = model(**inputs)

            step_loss = loss.item()
            eval_loss += step_loss

            # collect predictions, labels and input ids for the metrics;
            # the first input-id column is dropped before scoring
            preds.append(logits.detach().cpu().numpy())
            out_label_ids.append(inputs["labels"].detach().cpu().numpy())
            input_ids_list.append(inputs["input_ids"][:, 1:].detach().cpu().numpy())

            if step % print_stats_at_step == 0:
                epoch_iterator.set_description(
                    f'Eval Iter: {step+1}| step_loss: {step_loss: .3f}'
                )

    preds = np.vstack(preds)
    input_ids_list = np.vstack(input_ids_list)
    out_label_ids = np.vstack(out_label_ids)
    epoch_loss = eval_loss / len(dataloader)
    results = get_result_matrix(
        epoch_loss, label_map, preds, input_ids_list,
        out_label_ids, sep_token_id, give_lists=give_lists
    )

    if summary_writer:
        summary_writer.add_scalar("Loss/eval", epoch_loss)
        summary_writer.add_scalar("F1_epoch/eval", results["f1"])
        summary_writer.add_scalar("Precision_epoch/eval", results["precision"])
        summary_writer.add_scalar("Recall_epoch/eval", results["recall"])

    return results


# The functions below are helpers for inference
def predictions_from_model(model, tokenizer, dataset, batch_size, label2idx, device):
    """Run ``model`` over ``dataset`` and return predicted label strings.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...,
            token_type_ids=..., labels=None)``; the first element of its
            return value is used as the per-token predicted label ids.
        tokenizer: only ``tokenizer.sep_token_id`` is read.
        dataset: dataset whose batches start with (input_ids, input_mask,
            segment_ids).
        batch_size: batch size of the data loader.
        label2idx: dict mapping label string -> label id.
        device: device the batches are moved to.

    Returns:
        One list of label strings per sentence, in dataset order; each list
        stops at the first separator token. An empty dataset yields ``[]``.
    """
    model.eval()
    dataloader = DataLoader(
        dataset, sampler=SequentialSampler(dataset), batch_size=batch_size
    )

    # Collect per-batch arrays and join them once at the end; appending to a
    # growing array would re-copy everything on every batch.
    pred_batches = []
    input_id_batches = []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            batch = tuple(t.to(device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "labels": None
            }
            # getting outputs
            logits, _, _ = model(**inputs)

            pred_batches.append(logits.detach().cpu().numpy())
            # the first input-id column is dropped so positions line up with
            # the predictions
            input_id_batches.append(
                inputs["input_ids"][:, 1:].detach().cpu().numpy()
            )

    if not pred_batches:
        return []

    pred_logits = np.concatenate(pred_batches, axis=0)
    input_ids_list = np.concatenate(input_id_batches, axis=0)

    idx2label = {i: label for label, i in label2idx.items()}
    sep_token_id = tokenizer.sep_token_id
    prediction_labels = []
    for sentence_preds, sentence_input_ids in zip(pred_logits, input_ids_list):
        sentence_labels = []
        for pred, token_id in zip(sentence_preds, sentence_input_ids):
            if token_id == sep_token_id:
                break
            sentence_labels.append(idx2label[pred])
        prediction_labels.append(sentence_labels)
    return prediction_labels


def align_predicted_labels_with_original_sentence_tokens(predicted_labels, examples, features, max_seq_length, num_special_tokens):
    """Map sub-token level predictions back onto the original words.

    The model predicts one label per sub-token; each word takes the label
    predicted at ``feature.orig_to_token_index[word]`` (the position recorded
    for the word's first sub-token).

    Args:
        predicted_labels: one list of predicted label strings per sentence.
        examples: objects exposing ``labels`` (the reference word labels).
        features: objects exposing ``orig_to_token_index``, aligned with
            ``predicted_labels``.
        max_seq_length: sequence length used when the features were built.
        num_special_tokens: number of special tokens inside that length.

    Returns:
        ``(aligned_predicted_labels, reference_labels)``. Words whose token
        position was truncated away, or for which no prediction exists, are
        labelled "O".
    """
    max_word_tokens = max_seq_length - num_special_tokens
    aligned_predicted_labels = []
    for feature, sentence_labels in zip(features, predicted_labels):
        # A word can point past the last prediction either because it was
        # truncated or because it produced no sub-tokens at the end of the
        # sentence; both cases fall back to "O" instead of an IndexError.
        limit = min(max_word_tokens, len(sentence_labels))
        aligned_predicted_labels.append([
            sentence_labels[token_idx] if token_idx < limit else "O"
            for token_idx in feature.orig_to_token_index
        ])

    return aligned_predicted_labels, [ex.labels for ex in examples]


def convert_to_ents(tokens, tags):
    """Group tagged tokens into entity spans with character offsets.

    Args:
        tokens: the words of a sentence.
        tags: one tag per token; "O" for non-entities, otherwise a
            two-character prefix (e.g. "B-" / "I-") followed by the entity
            type.

    Returns:
        ``(text, entities)`` where ``text`` is the tokens joined by single
        spaces and ``entities`` is a list of dicts with the keys "type",
        "entity", "start_offset" and "end_offset". Offsets index into
        ``text`` with an exclusive end, i.e.
        ``text[start_offset:end_offset] == entity``.

    A span ends at an "O" tag, at a tag of a different type, at a "B-" tag
    (which opens a new span even for the same type) or at the end of the
    sentence.
    """
    text = " ".join(tokens)

    # character position at which every token starts inside `text`
    token_starts = []
    cursor = 0
    for token in tokens:
        token_starts.append(cursor)
        cursor += len(token) + 1  # +1 for the joining space

    entities = []

    def _close_span(first, last_exclusive, ent_type):
        # record tokens[first:last_exclusive] as one entity
        surface = " ".join(tokens[first:last_exclusive])
        begin = token_starts[first]
        entities.append({
            "type": ent_type,
            "entity": surface,
            "start_offset": begin,
            "end_offset": begin + len(surface)
        })

    span_start = None
    span_type = None
    n_seen = 0
    for offset, (_token, tag) in enumerate(zip(tokens, tags)):
        n_seen = offset + 1
        tag_type = None if tag == "O" else tag[2:]
        continues_span = (
            span_type is not None
            and tag_type == span_type
            and not tag.startswith("B-")
        )
        if continues_span:
            continue
        if span_type is not None:
            _close_span(span_start, offset, span_type)
            span_start = None
            span_type = None
        if tag_type is not None:
            # start of a new entity
            span_start = offset
            span_type = tag_type

    # catches an entity that goes up until the last token
    if span_type is not None:
        _close_span(span_start, n_seen, span_type)
    return text, entities

# Building BERT models

In [8]:
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
from torchcrf import CRF


class BertCrfForNER(BertModel):
    """
    NER model: a BERT encoder, a linear projection to label scores and a CRF
    layer on top, built on huggingface's BertModel.
    """
    def __init__(self, config, pad_idx, sep_idx, num_labels):
        """Create the encoder, the projection and the CRF layer.
        Args:
            config (TYPE): model config file (similar to bert_config.json)
            pad_idx (TYPE): id of the padding token in the tokenizer
            sep_idx (TYPE): id of the separator token in the tokenizer
            num_labels : total number of labels
        """
        super(BertCrfForNER, self).__init__(config)
        self.num_labels, self.pad_idx, self.sep_idx = num_labels, pad_idx, sep_idx

        # NOTE(review): the parent constructor has already run with `config`
        # and a second BertModel is created here; only `self.bert` is used
        # in forward().
        self.bert = BertModel(config)
        # NOTE(review): this dropout is created but never applied in forward().
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.crf_layer = CRF(self.num_labels, batch_first=True)
        self.linear = nn.Linear(config.hidden_size, self.num_labels)
        self.init_weights()

    def create_mask_for_crf(self, inp):
        """Build the mask given to the CRF layer.
           Positions holding the pad or the sep token are masked out.
        Args:
            inp (TYPE): input ids (same layout as the CRF inputs)
        """
        is_pad = inp == self.pad_idx
        is_sep = inp == self.sep_idx
        # keep every position that is neither padding nor separator
        return ~(is_pad | is_sep)

    def forward(
        self, input_ids, attention_mask=None, token_type_ids=None,
        position_ids=None, head_mask=None, labels=None
    ):
        """Forward propagate.
        Args:
            input_ids (TYPE): bert input ids
            attention_mask (None, optional): attention mask for bert
            token_type_ids (None, optional): token type ids for bert
            position_ids (None, optional): position ids for bert
            head_mask (None, optional): head mask for bert
            labels (None, optional): labels required while training crf
        Returns:
            (decoded label ids, labels without their first column or None,
            negated CRF output as loss or None when no labels are given)
        """
        encoder_outputs = self.bert(
            input_ids, attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask
        )
        # encoder_outputs[0] holds the token embeddings; project them to
        # label scores and drop the first (cls) position
        emissions = self.linear(encoder_outputs[0])[:, 1:, :]
        if labels is not None:
            labels = labels[:, 1:]
        token_ids = input_ids[:, 1:]

        crf_mask = self.create_mask_for_crf(token_ids)
        loss = None
        if labels is not None:
            # the loss is the CRF output multiplied by -1
            crf_output = self.crf_layer(emissions, labels, mask=crf_mask)
            loss = crf_output * torch.tensor(-1, device=self.device)

        # decoding is done without a mask
        decoded = self.crf_layer.decode(emissions)
        out = torch.tensor(decoded, dtype=torch.long, device=self.device)
        return out, labels, loss


In [9]:
class BertForTokenClassification(BertModel):
    """
    Simply doing token classification over bert outputs.

    The BERT token embeddings go through a stack of linear layers; every
    layer but the last is followed by ReLU and dropout.
    """
    def __init__(self, config, num_labels, classification_layer_sizes=None):
        """Create the encoder and the classification head.
        Args:
            config (TYPE): model config (similar to bert_config.json)
            num_labels : output size of the last linear layer
            classification_layer_sizes (None, optional): sizes of the hidden
                linear layers placed before the output layer; None or empty
                means a single linear layer.
        """
        super(BertForTokenClassification, self).__init__(config)
        # copy, so the caller's sequence is never shared or modified
        hidden_sizes = list(classification_layer_sizes or [])

        self.num_labels = num_labels
        self.dropout_layer = nn.Dropout(config.hidden_dropout_prob)
        self.bert = BertModel(config)
        self.input_layer_sizes = [config.hidden_size] + hidden_sizes
        self.output_layer_size = hidden_sizes + [self.num_labels]
        self.classification_module = nn.ModuleList(
            nn.Linear(inp, out)
            for inp, out in zip(self.input_layer_sizes, self.output_layer_size)
        )
        self.num_linear_layer = len(hidden_sizes) + 1
        self.init_weights()

    def forward(
        self, input_ids, attention_mask=None, token_type_ids=None,
        position_ids=None, head_mask=None, labels=None
    ):
        """Forward propagate.
        Args:
            input_ids (TYPE): bert input ids
            attention_mask (None, optional): attention mask for bert
            token_type_ids (None, optional): token type ids for bert
            position_ids (None, optional): position ids for bert
            head_mask (None, optional): head mask for bert
            labels (None, optional): label ids; required to get a loss
        Returns:
            (predicted label ids, labels without their first column or None,
            cross-entropy loss or None when no labels are given)
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask
        )

        logits = outputs[0]
        for layer_idx, layer in enumerate(self.classification_module):
            if layer_idx + 1 != self.num_linear_layer:
                logits = self.dropout_layer(F.relu(layer(logits)))
            else:
                logits = layer(logits)

        # escaping cls token
        logits = logits[:, 1:, :].contiguous()
        if labels is not None:
            labels = labels[:, 1:].contiguous()
        # attention_mask is optional, so it may only be sliced when present
        if attention_mask is not None:
            attention_mask = attention_mask[:, 1:].contiguous()

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            # Negative label ids can never be class indices; they mark
            # special/padding positions and are left out of the loss.
            active_loss = flat_labels >= 0
            # only keep active parts of the loss
            if attention_mask is not None:
                active_loss = active_loss & (attention_mask.view(-1) == 1)
            loss = loss_fct(flat_logits[active_loss], flat_labels[active_loss])

        _, out = torch.max(logits, dim=2)
        return out, labels, loss

In [10]:
class BertLstmCrf(BertModel):
    """On the outputs of Bert there is a LSTM layer.
    On top of the LSTM there is a  CRF layer.
    """
    def __init__(
        self, config, pad_idx, lstm_hidden_dim,
        num_lstm_layers, bidirectional, num_labels, sep_idx=None
    ):
        """Create the encoder, the LSTM, the projection and the CRF layer.
        Args:
            config (TYPE): model config (similar to bert_config.json)
            pad_idx (TYPE): id of the padding token in the tokenizer
            lstm_hidden_dim (TYPE): hidden size of the LSTM
            num_lstm_layers (TYPE): number of stacked LSTM layers
            bidirectional (TYPE): whether the LSTM is bidirectional
            num_labels (TYPE): total number of labels
            sep_idx (None, optional): id of the separator token; when given,
                separator positions are masked out of the CRF loss as well.
        """
        super(BertLstmCrf, self).__init__(config)
        self.dropout_prob = config.hidden_dropout_prob
        self.pad_idx = pad_idx
        # create_mask_for_crf reads this attribute; it used to be missing,
        # which made every forward pass fail with an AttributeError
        self.sep_idx = sep_idx
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers
        self.bidirectional = bidirectional
        self.num_labels = num_labels

        self.bert = BertModel(config)

        lstm_kwargs = dict(
            input_size=config.hidden_size, hidden_size=self.lstm_hidden_dim,
            num_layers=self.num_lstm_layers, bidirectional=self.bidirectional,
            batch_first=True
        )
        # dropout is only handed to the LSTM when layers are stacked
        if self.num_lstm_layers > 1:
            lstm_kwargs["dropout"] = self.dropout_prob
        self.lstm = nn.LSTM(**lstm_kwargs)

        # a bidirectional LSTM outputs both directions concatenated
        if self.bidirectional is True:
            self.linear = nn.Linear(self.lstm_hidden_dim*2, self.num_labels)
        else:
            self.linear = nn.Linear(self.lstm_hidden_dim, self.num_labels)

        self.crf_layer = CRF(self.num_labels, batch_first=True)
        self.dropout_layer = nn.Dropout(self.dropout_prob)

        self.init_weights()

    def create_mask_for_crf(self, inp):
        """Creates the mask fed to the crf layer.
        Padding positions are masked out, and separator positions too when
        ``sep_idx`` was provided.
        Args:
            inp (TYPE): input ids (same layout as the CRF inputs)
        """
        mask = inp != self.pad_idx
        if self.sep_idx is not None:
            mask = mask & (inp != self.sep_idx)
        return mask

    def forward(
        self, input_ids, attention_mask=None, token_type_ids=None,
        position_ids=None, head_mask=None, labels=None
    ):
        """Forward propagate.
        Args:
            input_ids (TYPE): bert input ids
            attention_mask (None, optional): attention mask for bert
            token_type_ids (None, optional): token type ids for bert
            position_ids (None, optional): position ids for bert
            head_mask (None, optional): head mask for bert
            labels (None, optional): labels required while training crf
        Returns:
            (decoded label ids, labels without their first column or None,
            negated CRF output as loss or None when no labels are given)
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask
        )
        sequence_output = outputs[0]

        lstm_out, _ = self.lstm(sequence_output)
        logits = self.linear(self.dropout_layer(lstm_out))

        # removing cls token
        logits = logits[:, 1:, :]
        if labels is not None:
            labels = labels[:, 1:]
        input_ids = input_ids[:, 1:]

        # creating mask for crf
        mask = self.create_mask_for_crf(input_ids)

        # crf part: the loss is the CRF output multiplied by -1
        if labels is not None:
            loss = self.crf_layer(logits, labels, mask=mask) * torch.tensor(-1, device=self.device)
        else:
            loss = None

        # decoding is done without a mask
        out = self.crf_layer.decode(logits)
        out = torch.tensor(out, dtype=torch.long, device=self.device)
        # out = [batch_Size, seq_len]
        return out, labels, loss

# Get scibert models
https://github.com/allenai/scibert#pytorch-models

# Train the model

In [11]:
import os
import json
import logging
import fire
import sys
import torch
from transformers import BertTokenizer, BertConfig
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

from seqeval.metrics import classification_report

In [12]:
# Device used by this notebook; hard-coded to the CPU.
DEVICE = torch.device("cpu")

In [33]:
def train_ner_model(model_config_path, data_dir, logger_file_dir=None, labels_file=None):
    """Train a BERT-based NER model, keep the best checkpoint and score the test split.

    The JSON configuration is read, train/dev data are loaded from
    ``data_dir``, the model is trained for ``num_epochs`` epochs, the weights
    of the epoch with the highest evaluation F1 are written to disk and
    reloaded, and a seqeval classification report is printed for the test
    split.

    Args:
        model_config_path: path to the JSON training configuration. Keys read
            here: ``final_model_saving_dir``, ``bert_config_path``,
            ``bert_vocab_path``, ``tokenizer_do_lower_case``,
            ``max_seq_length``, ``model_type``, ``bert_model_path``,
            ``finetune``, ``train_batch_size``, ``validation_batch_size``,
            ``num_epochs``, ``learning_rate``, ``epsilon``, ``warmup_steps``,
            ``max_grad_norm``, ``num_special_tokens`` and, optionally,
            ``weight_decay`` (defaults to 0.0).
        data_dir: directory forwarded to ``load_and_cache_examples`` for the
            "train", "dev" and "test" modes.
        logger_file_dir: existing directory in which ``logs.txt`` is written.
            ``None`` is handled like a non-existent directory.
        labels_file: path forwarded to ``get_labels``. ``None`` is handled
            like a non-existent file.

    Returns:
        Tuple ``(aligned_predicted_labels, true_labels)`` as returned by
        ``align_predicted_labels_with_original_sentence_tokens`` for the test
        split.

    Raises:
        SystemExit: when the config file, the model saving directory, the
            logging directory or the labels file is missing.
        ValueError: when ``model_type`` in the configuration is not supported.
    """
    # loading model config
    if not os.path.exists(model_config_path):
        print("model_config_path doesn't exist.")
        sys.exit()
    with open(model_config_path, "r", encoding="utf-8") as reader:
        model_config_dict = json.load(reader)

    # NOTE: file names are appended by plain concatenation, so
    # "final_model_saving_dir" has to end with a path separator. This is kept
    # as-is so the resulting paths match the ones NERTagger derives from the
    # same configuration.
    model_saving_dir = model_config_dict["final_model_saving_dir"]
    if not os.path.exists(model_saving_dir):
        print("model_saving_dir doesn't exist.")
        sys.exit()
    output_model_file = model_saving_dir + "pytorch_model.bin"
    output_config_file = model_saving_dir + "bert_config.json"
    output_vocab_file = model_saving_dir + "vocab.txt"

    # os.path.exists(None) raises TypeError, so the None default is checked
    # explicitly and reported like a missing directory.
    if logger_file_dir is None or not os.path.exists(logger_file_dir):
        print("logger_file_path doesn't exist.")
        sys.exit()
    # NOTE: basicConfig does nothing if the root logger already has handlers
    # (e.g. logging was configured earlier in the same kernel).
    logging.basicConfig(
        filename=os.path.join(logger_file_dir, "logs.txt"),
        filemode="w"
    )
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    if labels_file is None or not os.path.exists(labels_file):
        print("labels_file doesn't exist.")
        sys.exit()
    print("Labels file exist")

    logger.info("Training configurations are given below ::")
    for key, val in model_config_dict.items():
        logger.info("{} == {}".format(key, val))

    logger.info("Started training model :::::::::::::::::::::")

    bert_config = BertConfig.from_json_file(model_config_dict["bert_config_path"])
    bert_tokenizer = BertTokenizer.from_pretrained(
        model_config_dict["bert_vocab_path"],
        config=bert_config,
        do_lower_case=model_config_dict["tokenizer_do_lower_case"]
    )
    # saving config and tokenizer next to the model weights
    bert_tokenizer.save_vocabulary(output_vocab_file)
    bert_config.to_json_file(output_config_file)

    labels = get_labels(labels_file)
    logger.info("Labels for Ner are: {}".format(labels))

    label2idx = {l: i for i, l in enumerate(labels)}

    # preparing training data
    train_dataset = load_and_cache_examples(
        data_dir=data_dir,
        max_seq_length=model_config_dict["max_seq_length"],
        tokenizer=bert_tokenizer,
        label_map=label2idx,
        pad_token_label_id=label2idx["O"],
        mode="train", logger=logger
    )
    # preparing eval data
    eval_dataset = load_and_cache_examples(
        data_dir=data_dir,
        max_seq_length=model_config_dict["max_seq_length"],
        tokenizer=bert_tokenizer,
        label_map=label2idx,
        pad_token_label_id=label2idx["O"],
        mode="dev", logger=logger
    )
    logger.info("Training data and eval data loaded successfully.")

    # Only the "crf" model type is wired up for training; fail loudly for
    # anything else instead of hitting a NameError on `model` further down.
    model_type = model_config_dict["model_type"]
    if model_type == "crf":
        model = BertCrfForNER.from_pretrained(
            model_config_dict["bert_model_path"],
            config=bert_config,
            pad_idx=bert_tokenizer.pad_token_id,
            sep_idx=bert_tokenizer.sep_token_id,
            num_labels=len(labels)
        )
    else:
        raise ValueError(
            "Unsupported model_type for training: {!r} (supported: 'crf')".format(model_type)
        )

    logger.info("{} model loaded successfully.".format(model_config_dict["model_type"]))

    # checking whether to finetune or not
    if model_config_dict["finetune"] == True:
        logger.info("Finetuning bert.")
    else:
        for param in model.bert.parameters():
            param.requires_grad = False
        logger.info("Freezing Berts weights.")

    # preparing optimizer and scheduler
    # Biases and LayerNorm weights are never decayed; all other parameters use
    # the optional "weight_decay" config value (0.0 when the key is absent,
    # which reproduces the previous behaviour).
    no_decay = ["bias", "LayerNorm.weight"]
    weight_decay = model_config_dict.get("weight_decay", 0.0)
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": weight_decay
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0
        }
    ]
    # total optimizer steps
    t_total = int((len(train_dataset) / model_config_dict["train_batch_size"]) * model_config_dict["num_epochs"])
    logger.info("t_total : {}".format(t_total))

    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=model_config_dict["learning_rate"],
        eps=model_config_dict["epsilon"]
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=model_config_dict["warmup_steps"],
        num_training_steps=t_total
    )
    # Log an actual parameter count (previously the function object itself was formatted).
    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info("Trainable parameters: {}".format(num_trainable))

    model.to(DEVICE)

    # Start below any reachable F1 so the first epoch always writes a
    # checkpoint, even when the evaluation F1 is exactly 0.0.
    best_eval_f1 = float("-inf")
    checkpoint_saved = False
    for epoch in range(model_config_dict["num_epochs"]):
        train_result = train_epoch(
            model=model, dataset=train_dataset,
            batch_size=model_config_dict["train_batch_size"],
            label_map=label2idx,
            max_grad_norm=model_config_dict["max_grad_norm"],
            optimizer=optimizer, scheduler=scheduler, device=DEVICE,
            sep_token_id=bert_tokenizer.sep_token_id
        )
        eval_result = eval_epoch(
            model=model, dataset=eval_dataset,
            batch_size=model_config_dict["validation_batch_size"],
            label_map=label2idx, device=DEVICE, sep_token_id=bert_tokenizer.sep_token_id,
            give_lists=False
        )
        print(f'Epoch: {epoch + 1}')
        print(f'Train Loss: {train_result["loss"]: .4f}| Train F1: {train_result["f1"]: .4f}')
        print(f'Eval Loss: {eval_result["loss"]: .4f}| Eval F1: {eval_result["f1"]: .4f}')
        logger.info(f'Epoch: {epoch + 1}')
        logger.info(f'Train Loss: {train_result["loss"]: .4f}| Train F1: {train_result["f1"]: .4f}')
        logger.info(f'Eval Loss: {eval_result["loss"]: .4f}| Eval F1: {eval_result["f1"]: .4f}')

        if best_eval_f1 < eval_result["f1"]:
            best_eval_f1 = eval_result["f1"]
            # saving model to disk (unwrap a parallel wrapper if there is one)
            model_to_save = model.module if hasattr(model, "module") else model
            torch.save(model_to_save.state_dict(), output_model_file)
            checkpoint_saved = True
            print("Saved a better model.")
            logger.info("Saved a beter model")
            del model_to_save

    # loading the best model and test results
    if checkpoint_saved:
        model.load_state_dict(torch.load(output_model_file, map_location=DEVICE))
        logger.info("Loaded best model successfully.")
    else:
        # Happens only when no epoch ran; avoid loading a missing or stale file.
        logger.warning("No checkpoint was saved during training; evaluating the current weights.")

    test_dataset, test_examples, test_features = load_and_cache_examples(
        data_dir=data_dir,
        max_seq_length=model_config_dict["max_seq_length"],
        tokenizer=bert_tokenizer,
        label_map=label2idx,
        pad_token_label_id=label2idx["O"],
        mode="test", logger=logger,
        return_features_and_examples=True
    )
    logger.info("Test data loaded successfully.")

    test_label_predictions = predictions_from_model(
        model=model, tokenizer=bert_tokenizer,
        dataset=test_dataset,
        batch_size=model_config_dict["validation_batch_size"],
        label2idx=label2idx, device=DEVICE
    )
    # restructure test_label_predictions with real labels
    aligned_predicted_labels, true_labels = align_predicted_labels_with_original_sentence_tokens(
        test_label_predictions, test_examples, test_features, max_seq_length=model_config_dict["max_seq_length"],
        num_special_tokens=model_config_dict["num_special_tokens"]
    )
    print("Test Results classification report...")
    print(classification_report(true_labels, aligned_predicted_labels))
    return aligned_predicted_labels, true_labels

In [36]:
# Run training; positional arguments are model_config_path, data_dir, logger_file_dir, labels_file.
# NOTE(review): these are absolute, machine-specific paths - adjust before running elsewhere.
train_ner_model('/Users/sdeshpande/Desktop/text_analysis_scripts/biomedical_bert_ner/crf_ner_config.json', '/Users/sdeshpande/Desktop/bioinformatices/MTL-Bioinformatics-2016/data/NCBI-disease-IOB/', '/Users/sdeshpande/Desktop/text_analysis_scripts/biomedical_bert_ner/log_dir/', '/Users/sdeshpande/Desktop/text_analysis_scripts/biomedical_bert_ner/NCBI_disease_label.txt')

Calling BertTokenizer.from_pretrained() with the path to a single file or url is deprecated
Labels file exist
Creating features from dataset file at /Users/sdeshpande/Desktop/bioinformatices/MTL-Bioinformatics-2016/data/NCBI-disease-IOB/
Creating features from dataset file at /Users/sdeshpande/Desktop/bioinformatices/MTL-Bioinformatics-2016/data/NCBI-disease-IOB/
Some weights of the model checkpoint at /Users/sdeshpande/Desktop/text_analysis_scripts/biomedical_bert_ner/scibert_scivocab_uncased/pytorch_model.bin were not used when initializing BertCrfForNER: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertCrfForNER from the checkpoint of a model trained on another task or with another architecture (e.g


   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O'],
  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'],
  ['O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'B-Disease',
   'O',
   'O',
   'O',
   'O',
   'O'],
  ['O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'B-Disease',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O'],
  ['O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'B-Disease',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O'],
  ['O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
  

# Model interpretation

In [38]:
import os
import json
import logging
from pprint import pprint
import sys

import torch

from transformers import BertTokenizer, BertConfig


# Device handed to NERTagger in the demo at the bottom of this cell.
DEVICE = torch.device("cpu")
print("Device Being used as {} \n".format(DEVICE))


# Send root-logger output to inference_logs.txt, overwriting the file (filemode="w").
# NOTE(review): logging.basicConfig does nothing when the root logger already has
# handlers, e.g. after train_ner_model configured logging in the same kernel -
# confirm the inference log file is actually written in that case.
logging.basicConfig(
    filename="inference_logs.txt",
    filemode="w"
)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

class NERTagger:
    """Load a trained NER model from a training configuration and tag raw sentences."""

    def __init__(
        self, labels_file,
        model_config_path, device
    ):
        """Read the configuration, rebuild tokenizer and labels, and load the model.

        Args:
            labels_file: path forwarded to ``get_labels``.
            model_config_path: path to the JSON configuration; the keys
                ``final_model_saving_dir``, ``tokenizer_do_lower_case`` and
                ``model_type`` are read here, plus the model-type specific
                keys used in the branches below.
            device: device the loaded model is moved to and that is passed to
                ``predictions_from_model``.

        Raises:
            SystemExit: when the config file, the model saving directory or
                the labels file does not exist.
            ValueError: when ``model_type`` is none of "crf",
                "token_classification" or "lstm_crf".
        """
        self.model_config_path = model_config_path
        self.labels_file = labels_file
        self.device = device
        if os.path.exists(self.model_config_path):
            with open(self.model_config_path, "r", encoding="utf-8") as reader:
                self.model_config_dict = json.load(reader)
        else:
            print("model_config_path doesn't exist.")
            sys.exit()

        # NOTE: file names are appended by plain concatenation, so
        # "final_model_saving_dir" has to end with a path separator.
        model_saving_dir = self.model_config_dict["final_model_saving_dir"]
        if os.path.exists(model_saving_dir):
            self.model_file = model_saving_dir + "pytorch_model.bin"
            self.config_file = model_saving_dir + "bert_config.json"
            self.vocab_file = model_saving_dir + "vocab.txt"
        else:
            print("model_saving_dir doesn't exist.")
            sys.exit()
        if os.path.exists(self.labels_file):
            print("Labels file exist")
        else:
            print("labels_file doesn't exist.")
            sys.exit()

        self.bert_config = BertConfig.from_json_file(self.config_file)
        self.bert_tokenizer = BertTokenizer.from_pretrained(
            self.vocab_file,
            config=self.bert_config,
            do_lower_case=self.model_config_dict["tokenizer_do_lower_case"]
        )
        self.labels = get_labels(self.labels_file)
        self.label2idx = {l: i for i, l in enumerate(self.labels)}

        model_type = self.model_config_dict["model_type"]
        if model_type == "crf":
            self.model = BertCrfForNER.from_pretrained(
                self.model_file,
                config=self.bert_config,
                pad_idx=self.bert_tokenizer.pad_token_id,
                sep_idx=self.bert_tokenizer.sep_token_id,
                num_labels=len(self.labels)
            )
        elif model_type == "token_classification":
            self.model = BertForTokenClassification.from_pretrained(
                self.model_file,
                config=self.bert_config,
                num_labels=len(self.labels),
                classification_layer_sizes=self.model_config_dict["classification_layer_sizes"]
            )
        elif model_type == "lstm_crf":
            # NOTE(review): unlike the "crf" branch no sep_idx is passed here -
            # confirm BertLstmCrf provides a suitable default.
            self.model = BertLstmCrf.from_pretrained(
                self.model_file,
                config=self.bert_config,
                num_labels=len(self.labels),
                pad_idx=self.bert_tokenizer.pad_token_id,
                lstm_hidden_dim=self.model_config_dict["lstm_hidden_dim"],
                num_lstm_layers=self.model_config_dict["num_lstm_layers"],
                bidirectional=self.model_config_dict["bidirectional"]
            )
        else:
            # Without this branch self.model would stay unset and the failure
            # would surface as an AttributeError on the next line.
            raise ValueError(
                "Unsupported model_type: {!r} (expected 'crf', "
                "'token_classification' or 'lstm_crf')".format(model_type)
            )
        self.model.to(self.device)
        print("Model loaded successfully from the config provided.")

    def tag_sentences(self, sentence_list, logger, batch_size):
        """Predict entity tags for raw sentences.

        Args:
            sentence_list: sentences forwarded to ``load_and_cache_examples``
                in "inference" mode.
            logger: logger forwarded to ``load_and_cache_examples``.
            batch_size: batch size forwarded to ``predictions_from_model``.

        Returns:
            A list with one ``convert_to_ents(example.words, label_tags)``
            result per example, in input order.
        """
        dataset, examples, features = load_and_cache_examples(
            max_seq_length=self.model_config_dict["max_seq_length"],
            tokenizer=self.bert_tokenizer,
            label_map=self.label2idx,
            pad_token_label_id=self.label2idx["O"],
            mode="inference", data_dir=None,
            logger=logger, sentence_list=sentence_list,
            return_features_and_examples=True
        )

        label_predictions = predictions_from_model(
            model=self.model, tokenizer=self.bert_tokenizer,
            dataset=dataset, batch_size=batch_size,
            label2idx=self.label2idx, device=self.device
        )
        # restructure the predictions so they line up with the original words;
        # the second return value (true labels) is not needed at inference time
        aligned_predicted_labels, _ = align_predicted_labels_with_original_sentence_tokens(
            label_predictions, examples, features,
            max_seq_length=self.model_config_dict["max_seq_length"],
            num_special_tokens=self.model_config_dict["num_special_tokens"]
        )
        return [
            convert_to_ents(example.words, label_tags)
            for label_tags, example in zip(aligned_predicted_labels, examples)
        ]


# Demo: build a tagger from the training configuration and pretty-print the
# entities it finds in a few example sentences.
if __name__ == "__main__":
    # Example sentences passed to NERTagger.tag_sentences.
    sentence_list = [
        "Number of glucocorticoid receptors in lymphocytes and their sensitivity to hormone action .",
        "The study demonstrated a decreased level of glucocorticoid receptors ( GR ) in peripheral blood lymphocytes from hypercholesterolemic subjects , and an elevated level in patients with acute myocardial infarction .",
        "In the lymphocytes with a high GR number , dexamethasone inhibited [ 3H ] -thymidine and [ 3H ] -acetate incorporation into DNA and cholesterol , respectively , in the same manner as in the control cells .",
        "On the other hand , a decreased GR number resulted in a less efficient dexamethasone inhibition of the incorporation of labeled compounds .",
        "hese data showed that the sensitivity of lymphocytes to glucocorticoids changed only with a decrease of GR level .",
        "Treatment with I-hydroxyvitamin D3 ( 1-1.5 mg daily , within 4 weeks ) led to normalization of total and ionized form of Ca2+ and of 25 ( OH ) D , but did not affect the PTH content in blood .",
        "The data obtained suggest that under conditions of glomerulonephritis only high content of receptors to 1.25 ( OH ) 2D3 in lymphocytes enabled to perform the cell response to the hormone effect .",
        "To investigate whether the tumor expression of beta-2-microglobulin ( beta 2-M ) could serve as a marker of tumor biologic behavior , the authors studied specimens of breast carcinomas from 60 consecutive female patients .",
        "Presence of beta 2-M was analyzed by immunohistochemistry .",
        "I love data science",
        "Humira showed better results than Cimzia for treating psoriasis .",
        "Important advancements in the treatment of non - small cell lung cancer (NSCLC) have been achieved over the past two decades, increasing our understanding of the disease biology and mechanisms of tumour progression, and advancing early detection and multimodal care .",
        "The use of small molecule tyrosine kinase inhibitors and immunotherapy has led to unprecedented survival benefits in selected patients .",
        "However, the overall cure and survival rates for NSCLC remain low, particularly in metastatic disease .",
        "Therefore, continued research into new drugs and combination therapies is required to expand the clinical benefit to a broader patient population and to improve outcomes in NSCLC .",
        "The non-small cell lung cancer immune contexture. A major determinant of tumor characteristics and patient outcome ."
    ]

    # NOTE(review): absolute, machine-specific paths - adjust before running elsewhere.
    tagger = NERTagger(
        labels_file="/Users/sdeshpande/Desktop/text_analysis_scripts/biomedical_bert_ner/NCBI_disease_label.txt",
        model_config_path="/Users/sdeshpande/Desktop/text_analysis_scripts/biomedical_bert_ner/crf_ner_config.json",
        device=DEVICE
    )
    # Sentences are tagged in batches of 2.
    pprint(tagger.tag_sentences(sentence_list, logger=logger, batch_size=2))

Calling BertTokenizer.from_pretrained() with the path to a single file or url is deprecated
Device Being used as cpu 

Labels file exist
  0%|          | 0/8 [00:00<?, ?it/s]Model loaded successfully from the config provided.
100%|██████████| 8/8 [00:03<00:00,  2.32it/s][('Number of glucocorticoid receptors in lymphocytes and their sensitivity to '
  'hormone action .',
  [{'end_offset': 49,
    'entity': 'lymphocytes',
    'start_offset': 38,
    'type': 'Disease'}]),
 ('The study demonstrated a decreased level of glucocorticoid receptors ( GR ) '
  'in peripheral blood lymphocytes from hypercholesterolemic subjects , and an '
  'elevated level in patients with acute myocardial infarction .',
  [{'end_offset': 107,
    'entity': 'lymphocytes',
    'start_offset': 96,
    'type': 'Disease'},
   {'end_offset': 160,
    'entity': 'elevated',
    'start_offset': 152,
    'type': 'Disease'}]),
 ('In the lymphocytes with a high GR number , dexamethasone inhibited [ 3H ] '
  '-thymidine and 

In [18]:
# NOTE(review): `labels` is presumably defined in an earlier cell; this cell's execution
# count (18) is out of order relative to its neighbours - confirm it survives Restart & Run All.
labels

['O', 'B-Disease', 'I-Disease']