In [24]:
import torch
import random
import numpy as np
import os

seed = 7
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
from torch.utils.data import Dataset

categories = set()

class PeopleDaily(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f.read().split('\n\n')): # each line as one sentence
                if not line:
                    break
                sentence, labels = '', []
                for i, item in enumerate(line.split('\n')):
                    char, tag = item.split(' ')
                    sentence += char
                    if tag.startswith('B'):
                        labels.append([i, i, char, tag[2:]]) # Remove the B- or I-
                        categories.add(tag[2:])
                    elif tag.startswith('I'):
                        labels[-1][1] = i
                        labels[-1][2] += char
                Data[idx] = {
                    'sentence': sentence,
                    'labels': labels
                }
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [25]:
train_data = PeopleDaily('data/china-people-daily-ner-corpus/example.train')
valid_data = PeopleDaily('data/china-people-daily-ner-corpus/example.dev')
test_data = PeopleDaily('data/china-people-daily-ner-corpus/example.test')
print(train_data[0])

{'sentence': '海钓比赛地点在厦门与金门之间的海域。', 'labels': [[7, 8, '厦门', 'LOC'], [10, 11, '金门', 'LOC']]}


In [26]:
#create the label mapping
id2label = {0:'O'}
for c in list(sorted(categories)):
    id2label[len(id2label)] = f"B-{c}"
    id2label[len(id2label)] = f"I-{c}"
label2id = {v: k for k, v in id2label.items()}

print(id2label)
print(label2id)

{0: 'O', 1: 'B-LOC', 2: 'I-LOC', 3: 'B-ORG', 4: 'I-ORG', 5: 'B-PER', 6: 'I-PER'}
{'O': 0, 'B-LOC': 1, 'I-LOC': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-PER': 5, 'I-PER': 6}


In [27]:
from transformers import AutoTokenizer
import numpy as np

checkpoint = "bert-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

sentence = '海钓比赛地点在厦门与金门之间的海域。'
labels = [[7, 8, '厦门', 'LOC'], [10, 11, '金门', 'LOC']]

encoding = tokenizer(sentence, truncation=True)
tokens = encoding.tokens()
label = np.zeros(len(tokens),dtype=int)

for char_start, char_end, word, tag in labels:
    token_start = encoding.char_to_token(char_start)
    token_end = encoding.char_to_token(char_end)

    label[token_start] = label2id[f"B-{tag}"]
    label[token_start+1:token_end+1] = label2id[f"I-{tag}"]
print(tokens)
print(label)
print(id2label[id] for id in label)

['[CLS]', '海', '钓', '比', '赛', '地', '点', '在', '厦', '门', '与', '金', '门', '之', '间', '的', '海', '域', '。', '[SEP]']
[0 0 0 0 0 0 0 0 1 2 0 1 2 0 0 0 0 0 0 0]
<generator object <genexpr> at 0x0000026132A506D0>


In [45]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import numpy as np

checkpoint = "bert-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def collote_fn(batch_samples):
    batch_sentence, batch_tags  = [], []
    for sample in batch_samples:
        batch_sentence.append(sample['sentence'])
        batch_tags.append(sample['labels'])
    batch_inputs = tokenizer(
        batch_sentence,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    batch_label = np.zeros(batch_inputs['input_ids'].shape, dtype=int)
    for s_idx, sentence in enumerate(batch_sentence):
        encoding = tokenizer(sentence, truncation=True)
        batch_label[s_idx][0] = -100
        batch_label[s_idx][len(encoding.tokens())-1:] = -100
        for char_start, char_end, _, tag in batch_tags[s_idx]:
            token_start = encoding.char_to_token(char_start)
            token_end = encoding.char_to_token(char_end)
            batch_label[s_idx][token_start] = label2id[f"B-{tag}"]
            batch_label[s_idx][token_start+1:token_end+1] = label2id[f"I-{tag}"]
    return batch_inputs, torch.tensor(batch_label).long()

train_dataloader = DataLoader(train_data, batch_size=4, shuffle=True, collate_fn=collote_fn)
valid_dataloader = DataLoader(valid_data, batch_size=4, shuffle=False, collate_fn=collote_fn)
test_dataloader = DataLoader(test_data, batch_size=4, shuffle=False, collate_fn=collote_fn)

In [46]:
batch_X, batch_y = next(iter(train_dataloader))
print('batch_X shape:', {k: v.shape for k, v in batch_X.items()})
print('batch_y shape:', batch_y.shape)
print(batch_X)
print(batch_y)

batch_X shape: {'input_ids': torch.Size([4, 105]), 'token_type_ids': torch.Size([4, 105]), 'attention_mask': torch.Size([4, 105])}
batch_y shape: torch.Size([4, 105])
{'input_ids': tensor([[ 101,  800, 6432, 8024,  704, 1744, 1372, 3300,  671,  702, 8024, 1378,
         3968, 3221,  704, 1744, 4638,  671,  702, 4689, 8024, 6821, 3221, 1744,
         7354, 4852,  833, 2792, 1062, 6371, 4638,  511,  102,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0],
        [ 101, 1086, 3613, 8024, 4906, 2825, 7484, 1462, 1217, 6862, 1355, 2245,
         8024, 4294, 1166,

In [30]:
from torch import nn
from transformers import AutoConfig
from transformers import BertPreTrainedModel, BertModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')
#BERT model + fc :
class BertForNER(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(768, len(id2label))
        self.post_init()

    def forward(self, x):
        bert_output = self.bert(**x)
        sequence_output = bert_output.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        return logits

config = AutoConfig.from_pretrained(checkpoint)
model = BertForNER.from_pretrained(checkpoint, config=config).to(device)
print(model)

Some weights of BertForNER were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Using cuda device
BertForNER(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-

In [31]:
outputs = model(batch_X.to(device))
print(outputs.shape)

torch.Size([4, 65, 7])


In [38]:
from tqdm.auto import tqdm

def train_loop(dataloader, model, loss_fn, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = (epoch-1) * len(dataloader)

    model.train()
    for batch, (X, y) in enumerate(dataloader, start=1):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred.permute(0, 2, 1), y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')
        progress_bar.update(1)
    return total_loss

In [39]:
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2

y_true = [['O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'B-LOC', 'O'], ['B-PER', 'I-PER', 'O']]
y_pred = [['O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'B-LOC', 'O'], ['B-PER', 'I-PER', 'O']]

print(classification_report(y_true, y_pred, mode='strict', scheme=IOB2))

              precision    recall  f1-score   support

         LOC       0.50      0.50      0.50         2
         PER       1.00      1.00      1.00         1

   micro avg       0.67      0.67      0.67         3
   macro avg       0.75      0.75      0.75         3
weighted avg       0.67      0.67      0.67         3



In [36]:
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2

def test_loop(dataloader, model):
    true_labels, true_predictions = [], []

    model.eval()
    with torch.no_grad():
        for X, y in tqdm(dataloader):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            predictions = pred.argmax(dim=-1).cpu().numpy().tolist()
            labels = y.cpu().numpy().tolist()
            true_labels += [[id2label[int(l)] for l in label if l != -100] for label in labels]
            true_predictions += [
                [id2label[int(p)] for (p, l) in zip(prediction, label) if l != -100]
                for prediction, label in zip(predictions, labels)
            ]
    print(classification_report(true_labels, true_predictions, mode='strict', scheme=IOB2))

In [47]:
from transformers import AdamW, get_scheduler

learning_rate = 1e-5
epoch_num = 3

loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=learning_rate)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num*len(train_dataloader),
)

total_loss = 0.
for t in range(epoch_num):
    print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
    total_loss = train_loop(train_dataloader, model, loss_fn, optimizer, lr_scheduler, t+1, total_loss)
    test_loop(valid_dataloader, model)
print("Done!")

Epoch 1/3
-------------------------------


  0%|          | 0/5216 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC       0.96      0.96      0.96      1951
         ORG       0.91      0.91      0.91       984
         PER       0.98      0.98      0.98       884

   micro avg       0.95      0.95      0.95      3819
   macro avg       0.95      0.95      0.95      3819
weighted avg       0.95      0.95      0.95      3819

Epoch 2/3
-------------------------------


  0%|          | 0/5216 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC       0.97      0.96      0.96      1951
         ORG       0.92      0.93      0.93       984
         PER       0.99      0.98      0.99       884

   micro avg       0.96      0.96      0.96      3819
   macro avg       0.96      0.96      0.96      3819
weighted avg       0.96      0.96      0.96      3819

Epoch 3/3
-------------------------------


  0%|          | 0/5216 [00:00<?, ?it/s]

  0%|          | 0/580 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC       0.97      0.96      0.97      1951
         ORG       0.93      0.93      0.93       984
         PER       0.99      0.98      0.99       884

   micro avg       0.96      0.96      0.96      3819
   macro avg       0.96      0.96      0.96      3819
weighted avg       0.96      0.96      0.96      3819

Done!
