# Entity extraction using BERT

## Imports

In [1]:
import joblib
import torch
import torch.nn as nn
import transformers

import numpy as np
import pandas as pd

from sklearn import preprocessing
from sklearn import model_selection

from tqdm import tqdm
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup



In [2]:
import wandb
# Authenticate with Weights & Biases. When no key is stored it prompts for an
# API key; the returned boolean is shown as this cell's output.
wandb.login()

[34m[1mwandb[0m: [32m[41mERROR[0m Not authenticated.  Copy a key from https://app.wandb.ai/authorize


API Key:  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

## Configuration

In [4]:
class config:
    """Hyper-parameters, paths and the shared tokenizer used by every cell."""

    # Sequence length in word-pieces, counting [CLS], [SEP] and padding.
    MAX_LEN = 128

    # Batch sizes (reduce to 32 / 8 on a GPU with roughly 5.8 GB of memory).
    TRAIN_BATCH_SIZE = 64
    VALID_BATCH_SIZE = 16

    EPOCHS = 5

    # Pretrained checkpoint directory, best-checkpoint output file and the
    # annotated corpus CSV read by process_data.
    BASE_MODEL_PATH = "../input/bert-base-uncased/"
    MODEL_PATH = "model.bin"
    TRAINING_FILE = "../input/entity-annotated-corpus/ner_dataset.csv"

    # Built once, when the class body executes.
    TOKENIZER = transformers.BertTokenizer.from_pretrained(BASE_MODEL_PATH, do_lower_case=True)

In [11]:
import os

# Ask wandb to run offline; this kernel has no internet access.
os.environ.update({"WANDB_MODE": "offline"})

In [12]:
# Start a wandb run that records the main hyper-parameters from `config`.
wandb.init(
    project="entity-extraction-by-bert",
    config=dict(
        epochs=config.EPOCHS,
        train_batch_size=config.TRAIN_BATCH_SIZE,
        valid_batch_size=config.VALID_BATCH_SIZE,
        max_len=config.MAX_LEN,
        base_model_path=config.BASE_MODEL_PATH,
    ),
)

CommError: To use W&B in kaggle you must enable internet in the settings panel on the right.

[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop. See /kaggle/working/wandb/debug.log for full traceback.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop. See /kaggle/working/wandb/debug.log for full traceback.


## Dataset

In [13]:
class EntityDataset:
    """Map-style dataset turning word-level sentences into fixed-size BERT inputs.

    Every word is tokenized on its own, and its POS / tag label is copied onto
    each of its word-pieces. The sequence is then truncated, wrapped in
    [CLS] ... [SEP] and right-padded to ``config.MAX_LEN``.
    """

    def __init__(self, texts, pos, tags):
        # Parallel per-sentence sequences: texts[i] holds the words of sentence
        # i, pos[i] / tags[i] the already label-encoded ids for those words.
        self.texts = texts
        self.pos = pos
        self.tags = tags

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        """Return a dict of ``torch.long`` tensors, each of length ``config.MAX_LEN``.

        Keys: ``ids``, ``mask``, ``token_type_ids``, ``target_pos``, ``target_tag``.
        """
        text = self.texts[item]
        pos = self.pos[item]
        tags = self.tags[item]
        tokenizer = config.TOKENIZER

        ids = []
        target_pos = []
        target_tag = []
        for i, word in enumerate(text):
            word_ids = tokenizer.encode(word, add_special_tokens=False)
            # One word may become several pieces (abhishek -> ab ##hi ##sh ##ek);
            # every piece inherits the word's labels.
            ids.extend(word_ids)
            target_pos.extend([pos[i]] * len(word_ids))
            target_tag.extend([tags[i]] * len(word_ids))

        # Keep two slots free for the [CLS] and [SEP] tokens.
        body_len = config.MAX_LEN - 2
        ids = ids[:body_len]
        target_pos = target_pos[:body_len]
        target_tag = target_tag[:body_len]

        # Take special-token ids from the tokenizer rather than hard-coding the
        # bert-base-uncased values (101 / 102 / 0), so another checkpoint's
        # vocabulary also works.
        ids = [tokenizer.cls_token_id] + ids + [tokenizer.sep_token_id]
        # NOTE(review): [CLS]/[SEP] get label 0, which is a real class (index 0
        # of each LabelEncoder), and they have mask 1, so the loss and the
        # evaluation treat them as that class. Consider an ignored label.
        target_pos = [0] + target_pos + [0]
        target_tag = [0] + target_tag + [0]

        mask = [1] * len(ids)
        token_type_ids = [0] * len(ids)

        padding_len = config.MAX_LEN - len(ids)
        ids = ids + [tokenizer.pad_token_id] * padding_len
        mask = mask + [0] * padding_len
        token_type_ids = token_type_ids + [0] * padding_len
        target_pos = target_pos + [0] * padding_len
        target_tag = target_tag + [0] * padding_len

        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(mask, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
            "target_pos": torch.tensor(target_pos, dtype=torch.long),
            "target_tag": torch.tensor(target_tag, dtype=torch.long),
        }

## Training and evaluation functions

In [14]:
def train_fn(data_loader, model, optimizer, device, scheduler):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        data_loader: yields dicts of tensors that are passed to ``model`` as
            keyword arguments.
        model: module whose forward returns ``(tag_logits, pos_logits, loss)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch tensor is moved to.
        scheduler: learning-rate scheduler, stepped once per batch.

    Returns:
        Sum of batch losses divided by ``len(data_loader)``.
    """
    model.train()
    final_loss = 0
    for data in tqdm(data_loader, total=len(data_loader)):
        data = {k: v.to(device) for k, v in data.items()}
        optimizer.zero_grad()
        _, _, loss = model(**data)
        loss.backward()
        optimizer.step()
        scheduler.step()
        final_loss += loss.item()
        # No per-step torch.cuda.empty_cache(): it does not give PyTorch more
        # usable memory and forces the caching allocator to re-request memory
        # from the driver on every batch, which only slows training down.
    return final_loss / len(data_loader)


In [15]:
from sklearn.metrics import classification_report

def eval_fn(data_loader, model, device, tag_encoder=None, pos_encoder=None):
    """Evaluate ``model`` once over ``data_loader`` and print per-class reports.

    Prints a scikit-learn classification report for the tag head and one for
    the POS head. Both are computed over every position whose attention mask
    is 1.

    Args:
        data_loader: yields dicts of tensors (``ids``, ``mask``,
            ``token_type_ids``, ``target_pos``, ``target_tag``) that are
            passed to ``model`` as keyword arguments.
        model: module whose forward returns ``(tag_logits, pos_logits, loss)``.
        device: device every batch tensor is moved to.
        tag_encoder: fitted LabelEncoder for tags, used for label ids and
            class names. Defaults to the module-level ``enc_tag``.
        pos_encoder: fitted LabelEncoder for POS tags. Defaults to the
            module-level ``enc_pos``.

    Returns:
        Sum of batch losses divided by ``len(data_loader)``.
    """
    # Resolve the encoders up front so a missing global fails before the loop.
    tag_encoder = enc_tag if tag_encoder is None else tag_encoder
    pos_encoder = enc_pos if pos_encoder is None else pos_encoder

    model.eval()
    final_loss = 0
    all_preds_tag, all_targets_tag = [], []
    all_preds_pos, all_targets_pos = [], []

    with torch.no_grad():
        for data in tqdm(data_loader, total=len(data_loader)):
            data = {k: v.to(device) for k, v in data.items()}
            tag, pos, loss = model(**data)
            final_loss += loss.item()

            # Copy the mask to the host once per batch. Boolean-indexing the
            # 2-D (batch, seq) arrays keeps only unpadded positions and
            # flattens them sentence by sentence.
            # NOTE(review): [CLS]/[SEP] also have mask 1 (with target 0), so
            # they are scored here as class index 0.
            active = data["mask"].cpu().numpy() == 1
            all_preds_tag.extend(tag.argmax(2).cpu().numpy()[active])
            all_targets_tag.extend(data["target_tag"].cpu().numpy()[active])
            all_preds_pos.extend(pos.argmax(2).cpu().numpy()[active])
            all_targets_pos.extend(data["target_pos"].cpu().numpy()[active])

    avg_loss = final_loss / len(data_loader)

    # List every encoded label explicitly, so classes absent from this split
    # still get a (zero) row and line up with target_names.
    tag_labels = list(range(len(tag_encoder.classes_)))
    pos_labels = list(range(len(pos_encoder.classes_)))

    tag_report = classification_report(
        all_targets_tag, all_preds_tag,
        labels=tag_labels, target_names=tag_encoder.classes_, zero_division=0,
    )
    pos_report = classification_report(
        all_targets_pos, all_preds_pos,
        labels=pos_labels, target_names=pos_encoder.classes_, zero_division=0,
    )

    print("Tagging Classification Report:\n", tag_report)
    print("POS Classification Report:\n", pos_report)

    return avg_loss

## Loss function and model

In [16]:
def loss_fn(output, target, mask, num_labels):
    """Mean cross-entropy over the positions where ``mask == 1``.

    ``output`` is flattened to ``(-1, num_labels)``, and ``target`` / ``mask``
    to ``(-1,)``. Masked-out positions are relabelled with the loss's
    ``ignore_index``, so they contribute nothing to the mean.
    """
    criterion = nn.CrossEntropyLoss()
    flat_logits = output.view(-1, num_labels)
    keep = mask.view(-1) == 1
    ignored = torch.tensor(criterion.ignore_index).type_as(target)
    flat_labels = torch.where(keep, target.view(-1), ignored)
    return criterion(flat_logits, flat_labels)


class EntityModel(nn.Module):
    """BERT encoder with two token-classification heads: NER tag and POS.

    Both heads read the encoder's per-token hidden states through their own
    dropout layer. The returned loss is the mean of the two masked
    cross-entropy losses.
    """

    def __init__(self, num_tag, num_pos):
        super().__init__()
        self.num_tag = num_tag
        self.num_pos = num_pos
        self.bert = transformers.BertModel.from_pretrained(config.BASE_MODEL_PATH)
        self.bert_drop_1 = nn.Dropout(0.3)
        self.bert_drop_2 = nn.Dropout(0.3)
        # Size the heads from the encoder's config instead of hard-coding 768,
        # so checkpoints with a different hidden size also work.
        hidden_size = self.bert.config.hidden_size
        self.out_tag = nn.Linear(hidden_size, self.num_tag)
        self.out_pos = nn.Linear(hidden_size, self.num_pos)

    def forward(self, ids, mask, token_type_ids, target_pos, target_tag):
        """Return ``(tag_logits, pos_logits, loss)``.

        The logits are per token. ``loss`` averages the two heads' losses,
        ignoring positions where ``mask != 1``.
        """
        outputs = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        # Index instead of tuple-unpacking. Newer transformers return a
        # dict-like ModelOutput, and unpacking that yields its *keys* (strings);
        # outputs[0] is the sequence of hidden states for both return styles.
        o1 = outputs[0]

        bo_tag = self.bert_drop_1(o1)
        bo_pos = self.bert_drop_2(o1)

        tag = self.out_tag(bo_tag)
        pos = self.out_pos(bo_pos)

        loss_tag = loss_fn(tag, target_tag, mask, self.num_tag)
        loss_pos = loss_fn(pos, target_pos, mask, self.num_pos)
        loss = (loss_tag + loss_pos) / 2

        return tag, pos, loss

## Data processing

In [17]:
def process_data(data_path):
    """Load the NER CSV and group it into per-sentence word / POS / tag lists.

    Blank "Sentence #" cells are forward-filled so every row carries a
    sentence id (presumably the CSV sets it only on a sentence's first word).
    The POS and Tag columns are label-encoded before grouping.

    Args:
        data_path: path to the CSV, read with latin-1 encoding.

    Returns:
        ``(sentences, pos, tag, enc_pos, enc_tag)``. The first three are
        aligned object arrays holding one list per sentence: the words, the
        encoded POS ids and the encoded tag ids. The last two are the fitted
        LabelEncoders.
    """
    df = pd.read_csv(data_path, encoding="latin-1")
    # .ffill() replaces fillna(method="ffill"), which is deprecated in pandas.
    df["Sentence #"] = df["Sentence #"].ffill()

    enc_pos = preprocessing.LabelEncoder()
    enc_tag = preprocessing.LabelEncoder()
    df["POS"] = enc_pos.fit_transform(df["POS"])
    df["Tag"] = enc_tag.fit_transform(df["Tag"])

    # Use a single groupby pass. All three columns share the same group order,
    # so entry i of each result belongs to the same sentence.
    grouped = df.groupby("Sentence #")[["Word", "POS", "Tag"]].agg(list)
    sentences = grouped["Word"].values
    pos = grouped["POS"].values
    tag = grouped["Tag"].values
    return sentences, pos, tag, enc_pos, enc_tag

## Training

In [19]:
sentences, pos, tag, enc_pos, enc_tag = process_data(config.TRAINING_FILE)

# Persist the fitted label encoders so the inference cell can decode predictions.
meta_data = {
    "enc_pos": enc_pos,
    "enc_tag": enc_tag
}
joblib.dump(meta_data, "meta.bin")

num_pos = len(enc_pos.classes_)
num_tag = len(enc_tag.classes_)

(
    train_sentences,
    test_sentences,
    train_pos,
    test_pos,
    train_tag,
    test_tag
) = model_selection.train_test_split(
    sentences,
    pos,
    tag,
    random_state=42,
    test_size=0.1
)

train_dataset = EntityDataset(
    texts=train_sentences, pos=train_pos, tags=train_tag
)

# Reshuffle every epoch. train_test_split shuffles only once, so without
# shuffle=True every epoch would see the batches in the same order.
train_data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=config.TRAIN_BATCH_SIZE, shuffle=True, num_workers=4
)

valid_dataset = EntityDataset(
    texts=test_sentences, pos=test_pos, tags=test_tag
)

valid_data_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=config.VALID_BATCH_SIZE, num_workers=1
)

# Fall back to CPU so the cell still runs (slowly) without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EntityModel(num_tag=num_tag, num_pos=num_pos)
model.to(device)

# No weight decay on biases and LayerNorm parameters.
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_parameters = [
    {
        "params": [
            p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.001,
    },
    {
        "params": [
            p for n, p in param_optimizer if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]

# The scheduler is stepped once per batch, so the total step count is
# (batches per epoch) * epochs. len(train_data_loader) counts the final
# partial batch. The previous int(len / batch_size * epochs) undercounted,
# and the linear schedule clamps its multiplier at 0 past
# num_training_steps, so the last few steps ran with a zero learning rate.
num_train_steps = len(train_data_loader) * config.EPOCHS
optimizer = AdamW(optimizer_parameters, lr=3e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_train_steps
)

# Keep the checkpoint with the lowest validation loss seen so far.
best_loss = np.inf
for epoch in range(config.EPOCHS):
    train_loss = train_fn(train_data_loader, model, optimizer, device, scheduler)
    test_loss = eval_fn(valid_data_loader, model, device)
    print(f"Train Loss = {train_loss} Valid Loss = {test_loss}")
    if test_loss < best_loss:
        torch.save(model.state_dict(), config.MODEL_PATH)
        best_loss = test_loss

100%|██████████| 675/675 [10:40<00:00,  1.05it/s]
100%|██████████| 300/300 [00:20<00:00, 14.85it/s]


Tagging Classification Report:
               precision    recall  f1-score   support

       B-art       1.00      0.99      1.00      9672
       B-eve       0.00      0.00      0.00        40
       B-geo       0.85      0.89      0.87      5908
       B-gpe       0.92      0.91      0.92      1735
       B-nat       0.00      0.00      0.00        24
       B-org       0.71      0.66      0.69      3429
       B-per       0.83      0.82      0.83      2625
       B-tim       0.88      0.84      0.86      2207
       I-art       0.00      0.00      0.00        43
       I-eve       0.00      0.00      0.00        42
       I-geo       0.77      0.70      0.74       861
       I-gpe       0.00      0.00      0.00        20
       I-nat       0.00      0.00      0.00         7
       I-org       0.65      0.62      0.63      2102
       I-per       0.82      0.94      0.87      3148
       I-tim       0.73      0.64      0.68       625
           O       0.99      0.99      0.99     9

100%|██████████| 675/675 [10:39<00:00,  1.05it/s]
100%|██████████| 300/300 [00:20<00:00, 14.92it/s]


Tagging Classification Report:
               precision    recall  f1-score   support

       B-art       1.00      0.99      1.00      9672
       B-eve       0.50      0.07      0.13        40
       B-geo       0.86      0.89      0.87      5908
       B-gpe       0.93      0.93      0.93      1735
       B-nat       0.00      0.00      0.00        24
       B-org       0.74      0.68      0.71      3429
       B-per       0.84      0.83      0.84      2625
       B-tim       0.89      0.85      0.87      2207
       I-art       0.00      0.00      0.00        43
       I-eve       0.00      0.00      0.00        42
       I-geo       0.77      0.73      0.75       861
       I-gpe       0.00      0.00      0.00        20
       I-nat       0.00      0.00      0.00         7
       I-org       0.66      0.67      0.67      2102
       I-per       0.82      0.94      0.88      3148
       I-tim       0.73      0.76      0.75       625
           O       0.99      0.99      0.99     9

100%|██████████| 675/675 [10:40<00:00,  1.05it/s]
100%|██████████| 300/300 [00:20<00:00, 14.77it/s]


Tagging Classification Report:
               precision    recall  f1-score   support

       B-art       1.00      0.99      0.99      9672
       B-eve       0.43      0.15      0.22        40
       B-geo       0.87      0.89      0.88      5908
       B-gpe       0.94      0.93      0.93      1735
       B-nat       0.12      0.04      0.06        24
       B-org       0.73      0.71      0.72      3429
       B-per       0.86      0.83      0.84      2625
       B-tim       0.89      0.85      0.87      2207
       I-art       0.00      0.00      0.00        43
       I-eve       0.00      0.00      0.00        42
       I-geo       0.78      0.75      0.76       861
       I-gpe       0.00      0.00      0.00        20
       I-nat       0.00      0.00      0.00         7
       I-org       0.66      0.72      0.69      2102
       I-per       0.84      0.93      0.88      3148
       I-tim       0.70      0.80      0.75       625
           O       0.99      0.99      0.99     9

100%|██████████| 675/675 [10:40<00:00,  1.05it/s]
100%|██████████| 300/300 [00:20<00:00, 14.83it/s]


Tagging Classification Report:
               precision    recall  f1-score   support

       B-art       1.00      0.99      0.99      9672
       B-eve       0.43      0.15      0.22        40
       B-geo       0.86      0.90      0.88      5908
       B-gpe       0.93      0.93      0.93      1735
       B-nat       0.17      0.04      0.07        24
       B-org       0.76      0.71      0.74      3429
       B-per       0.87      0.83      0.85      2625
       B-tim       0.88      0.87      0.88      2207
       I-art       0.00      0.00      0.00        43
       I-eve       0.00      0.00      0.00        42
       I-geo       0.76      0.77      0.76       861
       I-gpe       0.00      0.00      0.00        20
       I-nat       0.00      0.00      0.00         7
       I-org       0.69      0.72      0.70      2102
       I-per       0.85      0.94      0.89      3148
       I-tim       0.71      0.81      0.76       625
           O       0.99      0.99      0.99     9

100%|██████████| 675/675 [10:40<00:00,  1.05it/s]
100%|██████████| 300/300 [00:20<00:00, 14.83it/s]


Tagging Classification Report:
               precision    recall  f1-score   support

       B-art       1.00      0.99      0.99      9672
       B-eve       0.43      0.15      0.22        40
       B-geo       0.86      0.90      0.88      5908
       B-gpe       0.94      0.93      0.94      1735
       B-nat       0.20      0.04      0.07        24
       B-org       0.79      0.70      0.74      3429
       B-per       0.85      0.86      0.85      2625
       B-tim       0.89      0.87      0.88      2207
       I-art       0.00      0.00      0.00        43
       I-eve       0.00      0.00      0.00        42
       I-geo       0.77      0.78      0.77       861
       I-gpe       0.00      0.00      0.00        20
       I-nat       0.00      0.00      0.00         7
       I-org       0.72      0.68      0.70      2102
       I-per       0.85      0.93      0.89      3148
       I-tim       0.78      0.80      0.79       625
           O       0.99      0.99      0.99     9

## Inference

In [20]:
# Reload the label encoders that were saved alongside the model.
meta_data = joblib.load("meta.bin")
enc_pos = meta_data["enc_pos"]
enc_tag = meta_data["enc_tag"]

num_pos = len(enc_pos.classes_)
num_tag = len(enc_tag.classes_)

sentence = """
abhishek is going to india
"""
# Encoded with special tokens; only its length is used, to decide how many
# predicted positions to print.
tokenized_sentence = config.TOKENIZER.encode(sentence)

sentence = sentence.split()
print(sentence)
print(tokenized_sentence)

# forward() requires targets, so pass all-zero placeholders; the returned loss
# is ignored here.
test_dataset = EntityDataset(
    texts=[sentence],
    pos=[[0] * len(sentence)],
    tags=[[0] * len(sentence)]
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EntityModel(num_tag=num_tag, num_pos=num_pos)
# map_location lets a checkpoint saved on GPU load on whichever device is available.
model.load_state_dict(torch.load(config.MODEL_PATH, map_location=device))
model.to(device)
# A freshly built module is in training mode. Without eval() both
# Dropout(0.3) layers stay active and predictions vary from run to run.
model.eval()

with torch.no_grad():
    # Add a batch dimension of 1 and move every tensor to the device.
    data = {k: v.to(device).unsqueeze(0) for k, v in test_dataset[0].items()}
    tag, pos, _ = model(**data)

    # The printed positions start with [CLS] and end with [SEP].
    print(
        enc_tag.inverse_transform(
            tag.argmax(2).cpu().numpy().reshape(-1)
        )[:len(tokenized_sentence)]
    )
    print(
        enc_pos.inverse_transform(
            pos.argmax(2).cpu().numpy().reshape(-1)
        )[:len(tokenized_sentence)]
    )

['abhishek', 'is', 'going', 'to', 'india']
[101, 11113, 24158, 5369, 2243, 2003, 2183, 2000, 2634, 102]
['B-art' 'B-per' 'B-per' 'B-per' 'B-per' 'O' 'O' 'O' 'B-geo' 'B-art']
['$' 'NNP' 'NNP' 'NNP' 'NNP' 'VBZ' 'VBG' 'TO' 'NNP' '$']


In [21]:
import os
import shutil
import tempfile

from IPython.display import FileLink

# Directory to archive and the final location of the zip file.
directory_path = '/kaggle/working/'
output_zip = '/kaggle/working/kaggle_working.zip'

# Build the archive outside directory_path and then move it into place. An
# archive written inside the directory being compressed can end up containing
# itself (partially written), or a stale copy from an earlier run.
if os.path.exists(output_zip):
    os.remove(output_zip)
with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_zip = shutil.make_archive(os.path.join(tmp_dir, 'kaggle_working'), 'zip', directory_path)
    shutil.move(tmp_zip, output_zip)

# Display a download link for the archive.
FileLink(output_zip)

In [22]:
import os
import shutil
import tempfile

from IPython.display import FileLink

# Directory to archive and the final location of the zip file.
directory_path = '/kaggle/working/'
output_zip = '/kaggle/working/kaggle_working.zip'

# Build the archive in a temporary directory and move it into
# /kaggle/working, where it is downloadable. Writing it straight into
# directory_path could make it include itself or a stale copy. The former
# os.rename(output_zip, output_zip) was a no-op and has been removed.
if os.path.exists(output_zip):
    os.remove(output_zip)
with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_zip = shutil.make_archive(os.path.join(tmp_dir, 'kaggle_working'), 'zip', directory_path)
    shutil.move(tmp_zip, output_zip)

# Display a download link for the archive.
FileLink('/kaggle/working/kaggle_working.zip')

In [23]:
import os
import shutil
import tempfile

# Compress /kaggle/working/ into /kaggle/working/kaggle_working.zip. The
# archive is built in a temporary directory and then moved into place; one
# written inside the directory being compressed could include itself or a
# stale copy from an earlier run.
output_zip = '/kaggle/working/kaggle_working.zip'
if os.path.exists(output_zip):
    os.remove(output_zip)
with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_zip = shutil.make_archive(os.path.join(tmp_dir, 'kaggle_working'), 'zip', '/kaggle/working/')
    archive_path = shutil.move(tmp_zip, output_zip)

# Show the archive path as the cell output.
archive_path

'/kaggle/working/kaggle_working.zip'