In [1]:
import json
import pandas as pd
import numpy as np
import seaborn as sns
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, BertForTokenClassification, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from seqeval.metrics import f1_score, accuracy_score

In [2]:
# Count the visible CUDA devices and pick the GPU when CUDA is usable,
# otherwise fall back to the CPU.
n_gpu = torch.cuda.device_count()
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device, n_gpu)

cuda 1


In [3]:
# Length settings shared by the cells below: S_LEN is the padded sub-token
# length of an encoded sentence; TOKEN_LEN is the default `max_token_len`.
TOKEN_LEN, S_LEN = 150, 256
# Mini-batch size used by both data loaders.
bs = 16
# Event types a trigger can be tagged with.
LABELS = [
    'Exhibit',
    'Manoeuvre',
    'Deploy',
    'Experiment',
    'Accident',
    'Indemnity',
    'Support',
]

In [4]:
# Tag inventory: 'O' and 'PAD' first, then one B- tag per event type,
# then one I- tag per event type.
TOKEN_LABELS = [
    'O',
    'PAD',
    *('B-' + event for event in LABELS),
    *('I-' + event for event in LABELS),
]
# Tag string -> integer class id (the tag's position in TOKEN_LABELS).
TOKEN_MAPPING = dict(zip(TOKEN_LABELS, range(len(TOKEN_LABELS))))
TOKEN_MAPPING

{'O': 0,
 'PAD': 1,
 'B-Exhibit': 2,
 'B-Manoeuvre': 3,
 'B-Deploy': 4,
 'B-Experiment': 5,
 'B-Accident': 6,
 'B-Indemnity': 7,
 'B-Support': 8,
 'I-Exhibit': 9,
 'I-Manoeuvre': 10,
 'I-Deploy': 11,
 'I-Experiment': 12,
 'I-Accident': 13,
 'I-Indemnity': 14,
 'I-Support': 15}

In [5]:
def _read_json(path):
    """Return the parsed content of the JSON file at `path`.

    The file is decoded explicitly as UTF-8 so that reading does not depend
    on the platform's default locale encoding.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Training, validation and test splits.
tr_data = _read_json('data/train_7000.json')
val_data = _read_json('data/valid_1500.json')
test_data = _read_json('data/test_dataset_A.json')

In [6]:
# Tokenizer for the `bert-base-chinese` checkpoint; DataSet uses it to encode
# sentences and to obtain character offset mappings.
TOKENIZER = AutoTokenizer.from_pretrained('bert-base-chinese')

# dataset

In [7]:
class DataSet(object):
    """Encoded sentences plus optional per-sub-token trigger labels.

    Every sentence is a list of word strings. A '。' is appended to each
    sentence (presumably so that an exclusive trigger end equal to the
    original sentence length still refers to a real word).

    Args:
        tokens: list of sentences, each a list of word strings.
        max_token_len: stored on the instance; not read by any method here.
        max_s_len: length every encoded sentence is padded to, and the width
            of the label matrix.
        labels: optional list with one ``(event_type, (start, end))`` pair per
            sentence, using word-level offsets with ``end`` exclusive. A falsy
            ``event_type`` marks a sentence without a trigger. ``None`` means
            the data is unlabelled.
    """

    def __init__(self, tokens, max_token_len=TOKEN_LEN, max_s_len=S_LEN, labels=None):
        self.tokens = [x + ['。'] for x in tokens]
        self.max_token_len = max_token_len
        self.max_s_len = max_s_len
        self.token_idxes = self._get_tokens_idxes()
        self.encoded = self._encode_setences()
        # Only build the label matrix when labels were actually supplied.
        self.labels = self._get_label(labels) if labels is not None else None

    def _get_tokens_idxes(self):
        """Return, per sentence, the character offset at which each word starts.

        Offsets refer to the space-joined sentence built in `_encode_setences`,
        hence the `+ 1` per word for the separating space. The list has one
        more entry than the sentence has words.
        """
        token_idxes = []
        for s in self.tokens:
            token_idx = [0] + np.cumsum([len(x) + 1 for x in s]).tolist()
            token_idxes.append(token_idx)
        return token_idxes

    def _encode_setences(self):
        """Tokenize the space-joined sentences, padded to `self.max_s_len`.

        No truncation is requested, so a sentence that encodes to more than
        `self.max_s_len` sub-tokens is not shortened here.
        """
        setences = [' '.join(s) for s in self.tokens]
        # Pad to the instance's own length so the encoding always matches the
        # width of the label matrix built in `_get_label`.
        return TOKENIZER(
            setences,
            padding='max_length',
            return_offsets_mapping=True,
            max_length=self.max_s_len,
            return_tensors='np',
        )

    def _get_label(self, labels):
        """Build the ``(n_sentences, max_s_len)`` matrix of tag ids.

        Positions with attention mask 0 receive the 'PAD' id and all other
        positions start at 0 ('O'). For a sentence with a trigger, the first
        sub-token of the trigger gets the 'B-<type>' id and the following
        sub-tokens, up to the sub-token that starts word `end`, get 'I-<type>'.

        Raises:
            Exception: any lookup error is re-raised unchanged after the
                relevant lookup tables have been printed.
        """
        result = np.zeros((len(labels), self.max_s_len)) + TOKEN_MAPPING['PAD'] * (1 - self.encoded.attention_mask)
        for i, (event_type, pos) in enumerate(labels):
            if not event_type:
                continue
            start, end = pos
            # Character start offset -> sub-token index. Entries whose end
            # offset is 0 are skipped (presumably special and padding tokens).
            id_setence_mapping = {x: idx for idx, (x, y) in enumerate(self.encoded.offset_mapping[i]) if y}
            token_idx = self.token_idxes[i]
            try:
                first = id_setence_mapping[token_idx[start]]
                last = id_setence_mapping[token_idx[end]]
                result[i, first] = TOKEN_MAPPING['B-' + event_type]
                result[i, first + 1: last] = TOKEN_MAPPING['I-' + event_type]
            except Exception as e:
                # Dump the lookup tables to make alignment failures debuggable,
                # then propagate the original error with its traceback intact.
                print(id_setence_mapping)
                print(token_idx)
                print(i, start, end)
                print(e)
                raise
        return result

    def to_tensor_data(self):
        """Return a `TensorDataset` of (input_ids, attention_mask[, labels]).

        Input ids and labels are converted to long tensors and the attention
        mask to a bool tensor; labels are included only when present.
        """
        input_ids = torch.tensor(self.encoded.input_ids, dtype=torch.long)
        masks = torch.tensor(self.encoded.attention_mask, dtype=torch.bool)
        if self.labels is not None:
            return TensorDataset(input_ids, masks, torch.tensor(self.labels, dtype=torch.long))
        return TensorDataset(input_ids, masks)
        
def transform_data(dataset, has_label=True):
    """Convert raw records into a `DataSet`.

    Args:
        dataset: iterable of records, each with a 'tokens' list and, when
            `has_label` is true, an 'event_mention' entry.
        has_label: whether trigger labels should be extracted.

    Returns:
        A `DataSet` built from the token lists; its labels are
        ``(event_type, (start, end))`` pairs, ``(None, None)`` for records
        without an event, or ``None`` altogether when `has_label` is false.
    """

    def _trigger_of(record):
        # Records with a falsy 'event_mention' carry no trigger.
        mention = record['event_mention']
        if not mention:
            return (None, None)
        kind = mention['event_type']
        first, last = mention['trigger']['offset']
        words = record['tokens']
        # Shrink the span when its boundary word is not alphanumeric.
        if not words[first].isalnum():
            first += 1
        if not words[last - 1].isalnum():
            last -= 1
        return (kind, (first, last))

    tokens = [record['tokens'] for record in dataset]
    triggers = [_trigger_of(record) for record in dataset] if has_label else None
    return DataSet(tokens, labels=triggers)

In [8]:
# Labelled splits: trigger labels are extracted alongside the tokens.
tr_dataset, val_dataset = transform_data(tr_data), transform_data(val_data)
# The test split is encoded without labels.
test_dataset = transform_data(test_data, has_label=False)

In [9]:
# Tensor datasets for the two labelled splits.
train_data = tr_dataset.to_tensor_data()
valid_data = val_dataset.to_tensor_data()

# Training batches are drawn in random order; validation keeps dataset order.
train_sampler = RandomSampler(train_data)
valid_sampler = SequentialSampler(valid_data)

train_dataloader = DataLoader(train_data, batch_size=bs, sampler=train_sampler)
valid_dataloader = DataLoader(valid_data, batch_size=bs, sampler=valid_sampler)

# model definition

In [10]:
class BertCrf(torch.nn.Module):
    """Thin wrapper around `BertForTokenClassification` for `bert-base-chinese`.

    NOTE(review): despite the name, no CRF layer is defined in this class;
    it only forwards to the wrapped token-classification model.

    Args:
        num_tags: number of output classes of the token-classification head.
    """

    def __init__(self, num_tags):
        super().__init__()
        self.bert = BertForTokenClassification.from_pretrained(
            "bert-base-chinese",
            num_labels=num_tags,
            output_attentions = False,
            output_hidden_states = False
        )

    def forward(self, x, attention_mask, labels=None):
        """Run the wrapped model and return its output unchanged.

        Args:
            x: input ids passed positionally to the wrapped model.
            attention_mask: attention mask forwarded to the wrapped model.
            labels: optional target tag ids. Defaults to ``None`` so the
                module can also be called on unlabelled data.
        """
        return self.bert(x, token_type_ids=None, attention_mask=attention_mask, labels=labels)

In [11]:
# Token-classification head with one output class per entry of TOKEN_MAPPING;
# attention maps and hidden states are not returned.
model = BertForTokenClassification.from_pretrained(
    "bert-base-chinese",
    output_hidden_states=False,
    output_attentions=False,
    num_labels=len(TOKEN_MAPPING),
)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-c

In [12]:
# Move the model to the selected device; as the cell's last expression the
# returned module is also displayed.
model.to(device)

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

# fine-tuning

In [13]:
# When True every parameter of the model is optimised; otherwise only the
# classification head is.
FULL_FINETUNING = True
if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    # Parameters exempt from weight decay: biases and LayerNorm weights.
    # The model's modules are named `LayerNorm`, so their parameters are
    # `LayerNorm.weight` / `LayerNorm.bias`; the TF-style names 'gamma' and
    # 'beta' are kept only for compatibility with older parameter naming.
    no_decay = ['bias', 'LayerNorm.weight', 'gamma', 'beta']
    # The per-group option must be spelled 'weight_decay'; any other key is
    # ignored by the optimizer and the group falls back to the default decay.
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters())
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]

optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=3e-6,
    eps=1e-8
)

In [14]:
# Number of passes over the training set and the gradient-norm clipping bound.
epochs = 50
max_grad_norm = 1.0

# One scheduler step is taken per training batch, so the schedule spans
# (batches per epoch) * (number of epochs) steps.
total_steps = epochs * len(train_dataloader)

# Linear schedule over all training steps, without a warm-up phase.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_training_steps=total_steps,
    num_warmup_steps=0,
)

In [None]:
## Store the average loss after each epoch so we can plot them.
loss_values, validation_loss_values = [], []
valid_acc, valid_f1 = [], []
for _ in range(epochs):
    # ========================================
    #               Training
    # ========================================
    # Perform one full pass over the training set.

    # Put the model into training mode.
    model.train()
    # Reset the total loss for this epoch.
    total_loss = 0

    # Training loop
    for batch in tqdm(train_dataloader):
        # add batch to gpu
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        # Always clear any previously calculated gradients before performing a backward pass.
        model.zero_grad()
        # Forward pass; the output carries a loss because `labels` is provided.
        outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
        loss = outputs.loss
        # Perform a backward pass to calculate the gradients.
        loss.backward()
        # track train loss
        total_loss += loss.item()
        # Clip the norm of the gradient
        # This is to help prevent the "exploding gradients" problem.
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
        # update parameters
        optimizer.step()
        # Update the learning rate.
        scheduler.step()

    # Calculate the average loss over the training data.
    avg_train_loss = total_loss / len(train_dataloader)
    print("Average train loss: {}".format(avg_train_loss))

    # Store the loss value for plotting the learning curve.
    loss_values.append(avg_train_loss)


    # ========================================
    #               Validation
    # ========================================
    # After the completion of each training epoch, measure our performance on
    # our validation set.

    # Put the model into evaluation mode
    model.eval()
    # Reset the validation loss for this epoch.
    eval_loss = 0
    predictions, true_labels = [], []
    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        # Telling the model not to compute or store gradients,
        # saving memory and speeding up validation
        with torch.no_grad():
            # Labels are passed as well, so the output holds both the loss
            # and the logits.
            outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
        eval_loss += outputs.loss.item()

        # Move predictions, labels and sequence lengths to the CPU once per
        # batch instead of reading GPU scalars one by one later on.
        predict_labels = outputs.logits.argmax(2).cpu().numpy()
        label_ids = b_labels.cpu().numpy()
        seq_len = b_input_mask.sum(1).cpu().numpy()

        # Keep only the attended (non-padded) prefix of every sequence.
        predictions.extend(row[:n] for row, n in zip(predict_labels, seq_len))
        true_labels.extend(row[:n] for row, n in zip(label_ids, seq_len))

    eval_loss = eval_loss / len(valid_dataloader)
    validation_loss_values.append(eval_loss)
    print("Validation loss: {}".format(eval_loss))

    # Drop positions whose gold tag is 'PAD' from BOTH sequences so that the
    # predicted and gold tag lists always stay aligned.
    pad_id = TOKEN_MAPPING['PAD']
    keep_masks = [row != pad_id for row in true_labels]
    predictions = [[TOKEN_LABELS[y] for y in row[keep]] for row, keep in zip(predictions, keep_masks)]
    valid_labels = [[TOKEN_LABELS[y] for y in row[keep]] for row, keep in zip(true_labels, keep_masks)]

    # seqeval metrics take the gold tags first and the predictions second.
    acc, f1 = accuracy_score(valid_labels, predictions), f1_score(valid_labels, predictions)
    print("Validation Accuracy: {}".format(acc), "Validation F1-Score: {}".format(f1))
    valid_acc.append(acc)
    valid_f1.append(f1)
    print()

100%|██████████| 438/438 [01:54<00:00,  3.82it/s]


Average train loss: 0.18269108624099734
Validation loss: 0.03527182188043569
Validation Accuracy: 0.991140145071429 Validation F1-Score: 0.4020752269779507



100%|██████████| 438/438 [02:01<00:00,  3.60it/s]


Average train loss: 0.033828101642897616
Validation loss: 0.031234228250352627
Validation Accuracy: 0.9914816384091469 Validation F1-Score: 0.4801950030469226



 19%|█▉        | 83/438 [00:23<01:42,  3.48it/s]

In [None]:
# Display the loss of whichever forward pass the loop above executed last.
outputs.loss

In [None]:
# Validation accuracy recorded after each completed epoch.
ax = sns.lineplot(x=range(len(valid_acc)), y=valid_acc)
ax.set(title="Validation accuracy per epoch", xlabel="Epoch", ylabel="Accuracy");

In [None]:
# Validation F1 score recorded after each completed epoch.
ax = sns.lineplot(x=range(len(valid_f1)), y=valid_f1)
ax.set(title="Validation F1 score per epoch", xlabel="Epoch", ylabel="F1 score");