In [41]:
import sys
sys.path.append("/work/multi_doc_analyzer")
sys.path.append("/work/relation_extraction/Bert_model/data/")

import torch as T
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.cuda
from allennlp.nn import util as nn_util
from multi_doc_analyzer.structure.structure import *
from multi_doc_analyzer.tokenization.tokenizer import MDATokenizer
from tqdm import tqdm

from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
from allennlp.data.token_indexers import TokenIndexer

from allennlp.data.instance import Instance
from allennlp.data.fields import TextField, LabelField, ArrayField

from ace05_set_reader import ACE05Reader

from allennlp.data.vocabulary import Vocabulary
from allennlp.data.iterators import BucketIterator, DataIterator, BasicIterator
from allennlp.nn.util import get_text_field_mask
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder
import random

from allennlp.data.token_indexers import PretrainedBertIndexer
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.bert_token_embedder import PretrainedBertEmbedder

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

In [42]:
train_path = "/work/LDC2006T06/dataset/dev/"
test_path = "/work/LDC2006T06/dataset/test/"
model_folder = "/work/model_checkpoint/bert_model_checkpoint/"
output_path = "/work/relation_extraction/Bert_model/model/"

In [43]:
class Config(dict):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for k, v in kwargs.items():
            setattr(self, k, v)
    
    def set(self, key, val):
        self[key] = val
        setattr(self, key, val)
        
config = Config(
    seed=1,
    batch_size=64,
    lr=3e-4,                # learning rate
    epochs=50,
    hidden_sz=128,
    arg_sz=3,
    max_seq_len=110
)

In [44]:
USE_GPU = T.cuda.is_available()
USE_GPU

True

In [45]:
# set seed for both CPU and CUDA
T.manual_seed(config.seed)

<torch._C.Generator at 0x7f882e67b810>

In [46]:
# from ace05_set_reader import ACE05Reader
# train_path = "/work/LDC2006T06/dataset/train/"
# reader = ACE05Reader(lang='en')
# doc_dicts = reader.read(train_path)
# tokenizer = MDATokenizer('bert-en')
# for doc in doc_dicts.values():
#     tokenizer.annotate_document(doc)
#     for s in doc.sentences: 
        

In [47]:
e_type2idx = {'X':0, 'O': 1, 'PER': 2, 'ORG': 3, 'LOC': 4, 'GPE': 5, 'FAC': 6, 'VEH': 7, 'WEA': 8}

r_label2idx = {'PHYS-lr': 1, 'PART-WHOLE-lr': 2, 'PER-SOC-lr': 3, 'ORG-AFF-lr': 4, 'ART-lr': 5, 'GEN-AFF-lr': 6,
               'PHYS-rl': 7, 'PART-WHOLE-rl': 8, 'PER-SOC-rl': 9, 'ORG-AFF-rl': 10, 'ART-rl': 11, 'GEN-AFF-rl': 12,
               'NONE': 0}

r_idx2label = {v: k for k, v in r_label2idx.items()}

class RelationDatasetReader(DatasetReader):
    """
    Reads Structure object formatted datasets files, and creates AllenNLP instances.
    """
    def __init__(self, tokenizer: Tokenizer=None, token_indexers: Dict[str, TokenIndexer]=None, 
                 MAX_WORDPIECES: int=config.max_seq_len, 
                 is_training = False, ace05_reader: ACE05Reader=None):
        # make sure results may be reproduced when sampling...
        super().__init__(lazy=False)
        random.seed(0)
        self.is_training = is_training
        self.ace05_reader = ace05_reader
        
        # NOTE AllenNLP automatically adds [CLS] and [SEP] word peices in the begining and end of the context,
        # therefore we need to subtract 2
        self.MAX_WORDPIECES = MAX_WORDPIECES - 2
        
        self.tokenizer = tokenizer or WordTokenizer()
        
        # BERT specific init
        self._token_indexers = token_indexers

    def text_to_instance(self, sentence: Sentence) -> Instance:
#         sentence_tokens = [Token(x) for x in self.tokenizer(sentence.text)]
#         sentence_tokens = [Token(i) for t in sentence.text for i in self.tokenizer(t)]
#         sentence_tokens = [Token(text=t) for t in sentence.text]
        sentence_tokens = []
#         td = self.tokenizer(sentence.text)
#         print(td)
        for t in sentence.tokens:
#             print(type(t.text), t.text)
            sentence_tokens.append(Token(text=t.text))
#         for t in sentence.text:
#             td = self.tokenizer(t)
#             assert len(td) == 1 or len(td) == 0
#             if td:
#                 sentence_tokens.append(Token(text=td[0]))
#             else:
#                 sentence_tokens.append(Token(text='[MASK]'))

        sentence_field = TextField(sentence_tokens, self._token_indexers)
        fields = {"tokens": sentence_field}
#         char_list_field = ListField([t for t in sentence.text])
#         fields['char_list'] = char_list_field

        e_tuple_check_dicts = {} # {(train_arg_l.id, train_arg_r.id):true_label, ...}
        if self.is_training: 
            for r in sentence.relation_mentions:
                train_arg_l, train_arg_r, true_label = r.get_left_right_args()
                e_tuple_check_dicts[(train_arg_l.id, train_arg_r.id)] = true_label

        # construct pair entities
        for arg_left_idx in range(len(sentence.entity_mentions)-1):
            for arg_right_idx in range(arg_left_idx+1, len(sentence.entity_mentions)):
                arg_left = sentence.entity_mentions[arg_left_idx]
                arg_right = sentence.entity_mentions[arg_right_idx]
                arg_vec = T.zeros(self.MAX_WORDPIECES + 2, dtype=T.long)
                arg_vec[:len(sentence_tokens)+2] = 1

                # +1 because the first token is [CLS]
                arg_vec[arg_left.token_b+1:arg_left.token_e+1] = e_type2idx[arg_left.type]
                arg_vec[arg_right.token_b+1:arg_right.token_e+1] = e_type2idx[arg_right.type]
                fields["arg_idx"] = ArrayField(arg_vec)


#                 fields["arg_left"] = SpanField(arg_left.char_b, arg_left.char_e, char_list_field)
#                 fields["arg_right"] = SpanField(arg_right.char_b, arg_right.char_e, char_list_field)
                if self.is_training:
                    if (arg_left.id, arg_right.id) in e_tuple_check_dicts.keys():
                        fields["label"] = LabelField(r_label2idx[e_tuple_check_dicts[(arg_left.id, arg_right.id)]], skip_indexing=True)
                    else:
                        fields["label"] = LabelField(r_label2idx['NONE'], skip_indexing=True)
                yield Instance(fields)
    
    def _read(self, file_path: str)->Iterator: 
        doc_dicts = self.ace05_reader.read(file_path)
        tokenizer = MDATokenizer('bert-en')
        for doc in doc_dicts.values():
            tokenizer.annotate_document(doc)
            for s in doc.sentences: 
                if len(s.tokens) <= config.max_seq_len:
                    for instance in self.text_to_instance(s):
                        yield instance

In [48]:
class BERT(Model):
    def __init__(self, word_embeddings: TextFieldEmbedder,
                out_sz: int=len(r_label2idx)):
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self._entity_embeddings = T.nn.Embedding(num_embeddings=len(e_type2idx), embedding_dim=config.arg_sz, padding_idx=0)
        self.gru = T.nn.GRU(word_embeddings.get_output_dim()+config.arg_sz, config.hidden_sz, batch_first=True)
        self.projection = nn.Linear(config.hidden_sz, out_sz)
        self.loss = nn.CrossEntropyLoss()
        
    def forward(self, tokens: Dict[str, T.tensor], arg_idx: T.tensor, label: T.tensor = None) -> Dict[str, T.tensor]:
        mask = get_text_field_mask(tokens)
        
        embeddings = self.word_embeddings(tokens)
        pad_len = embeddings.shape[-2]
        
        arg_idx = arg_idx[:,:pad_len]
        arg_idx = arg_idx.type(T.long)
        
        arg_emb = self._entity_embeddings(arg_idx)

        concat = T.cat((embeddings, arg_emb), -1)
        ot, ht = self.gru(concat, None) # revise this "None"
        ot = ot[:,-1,:]    
        class_logits = self.projection(ot)
        
        output = {"class_logits": class_logits}
        output["loss"] = self.loss(class_logits, label)

        return output

In [49]:
from scipy.special import expit # the sigmoid function
def tonp(tsr): return tsr.detach().cpu().numpy()

In [50]:
# Predict
class Predictor:
    def __init__(self, model: Model, iterator: DataIterator,
                 cuda_device: int=-1) -> None:
        self.model = model
        self.iterator = iterator
        self.cuda_device = cuda_device
        
    def _extract_data(self, batch) -> np.ndarray:
        out_dict = self.model(**batch)
        return expit(tonp(out_dict["class_logits"]))
    
    def predict(self, ds: Iterable[Instance]) -> np.ndarray:
        pred_generator = self.iterator(ds, num_epochs=1, shuffle=False)
        self.model.eval()
        pred_generator_tqdm = tqdm(pred_generator, total=self.iterator.get_num_batches(ds))
        preds = []
        with T.no_grad():
            for batch in pred_generator_tqdm:
                batch = nn_util.move_to_device(batch, self.cuda_device)
                preds.append(self._extract_data(batch))
        return np.concatenate(preds, axis=0)

In [51]:
def plot_comfusion_matrix(label_classes, predict_classes, out_folder, file_name):
    label_types = list(r_idx2label.values())
    print(type(label_classes), len(label_classes))
    cm = confusion_matrix(label_classes, predict_classes, label_types)
    print(cm)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(cm)
    for (i, j), z in np.ndenumerate(cm):
        ax.text(j, i, '{:0.0f}'.format(z), ha='center', va='center', color='white')
    fig.colorbar(cax)
    ax.set_xticklabels([''] + label_types)
    ax.set_yticklabels([''] + label_types)
    plt.xlabel('Predicted')
    plt.ylabel('True')

    plt.savefig(out_folder + 'confusion_matrix_' + file_name + '.png')
    plt.show()

    print("Accuracy:", )
    print("F1 score:", )

In [52]:

if __name__ == '__main__':

    ace05_reader = ACE05Reader(lang='en')
    
    token_indexer = PretrainedBertIndexer(
        pretrained_model="bert-base-uncased",
        max_pieces=config.max_seq_len
#         truncate_long_sequences=False,
#         do_lowercase=False               # for cased condition
    )
 
	# AllenNLP DatasetReader
    reader = RelationDatasetReader(
        is_training=True, 
        ace05_reader=ace05_reader, 
        tokenizer=lambda s: token_indexer.wordpiece_tokenizer(s),
        token_indexers={"tokens": token_indexer}
    )

    train_ds = reader.read(train_path)
    print(len(train_ds))
    vocab = Vocabulary()
    iterator = BucketIterator(batch_size=config.batch_size, sorting_keys=[("tokens", "num_tokens")])
    iterator.index_with(vocab)

    bert_embedder = PretrainedBertEmbedder(
        pretrained_model="bert-base-uncased",
        top_layer_only=True, # conserve memory   
    )
    word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                                                # we'll be ignoring masks so we'll need to set this to True
                                                               allow_unmatched_keys = True)
    model = BERT(word_embeddings)
    if USE_GPU:
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=config.lr)

0it [00:00, ?it/s]
  0%|          | 0/80 [00:00<?, ?it/s][A
 11%|█▏        | 9/80 [00:00<00:00, 88.07it/s][A
 34%|███▍      | 27/80 [00:00<00:00, 103.56it/s][A
 55%|█████▌    | 44/80 [00:00<00:00, 86.38it/s] [A
 75%|███████▌  | 60/80 [00:00<00:00, 99.94it/s][A
 91%|█████████▏| 73/80 [00:00<00:00, 106.48it/s][A
15897it [00:09, 1623.43it/s]0<00:00, 107.03it/s][A

15897





In [53]:
    # training
    from allennlp.training.trainer import Trainer

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        iterator=iterator,
        train_dataset=train_ds,
        cuda_device=0 if USE_GPU else -1,
        num_epochs=config.epochs,
    )

In [None]:
    # train the model 
    metrics = trainer.train()

loss: 0.4402 ||: 100%|██████████| 249/249 [00:28<00:00,  8.60it/s]
loss: 0.2565 ||: 100%|██████████| 249/249 [00:26<00:00,  9.50it/s]
loss: 0.1321 ||: 100%|██████████| 249/249 [00:25<00:00,  9.60it/s]
loss: 0.0762 ||: 100%|██████████| 249/249 [00:26<00:00,  9.56it/s]
loss: 0.0492 ||: 100%|██████████| 249/249 [00:26<00:00,  9.94it/s]
loss: 0.0369 ||: 100%|██████████| 249/249 [00:26<00:00,  9.49it/s]
loss: 0.0246 ||: 100%|██████████| 249/249 [00:26<00:00,  9.45it/s]
loss: 0.0188 ||: 100%|██████████| 249/249 [00:26<00:00,  9.47it/s]
loss: 0.0123 ||: 100%|██████████| 249/249 [00:26<00:00,  9.40it/s]
loss: 0.0127 ||: 100%|██████████| 249/249 [00:26<00:00,  8.58it/s]
loss: 0.0079 ||: 100%|██████████| 249/249 [00:26<00:00, 10.88it/s]
loss: 0.0055 ||: 100%|██████████| 249/249 [00:26<00:00,  9.81it/s]
loss: 0.0037 ||: 100%|██████████| 249/249 [00:26<00:00,  9.38it/s]
loss: 0.0035 ||: 100%|██████████| 249/249 [00:26<00:00, 11.47it/s]
loss: 0.0024 ||: 100%|██████████| 249/249 [00:26<00:00,  9.42i

In [None]:
    # load model
#     model.load_state_dict(T.load(config.model_folder + "/model.th"))

In [None]:
    # save 
    with open(model_folder+'model.th', 'wb') as f:
        T.save(model.state_dict(), f)

In [None]:
    # training data analysis
    seq_iterator = BasicIterator(batch_size=config.batch_size)
    seq_iterator.index_with(vocab)
    
    predictor = Predictor(model, seq_iterator, cuda_device=0 if USE_GPU else -1)
    train_preds = predictor.predict(train_ds) 
    
    label_types = [r_idx2label.get(i.fields['label'].label) for i in train_ds]
    predict_types = [r_idx2label.get(i) for i in np.argmax(train_preds, axis=-1)]

In [None]:
    plot_comfusion_matrix(label_types, predict_types, output_path, "train")

In [None]:
    # testing data analysis
    
    # AllenNLP DatasetReader
    reader = RelationDatasetReader(
        is_training=True, 
        ace05_reader=ace05_reader, 
        tokenizer=lambda s: token_indexer.wordpiece_tokenizer(s),
        token_indexers={"tokens": token_indexer}
    )
    
    test_ds = reader.read(test_path)
    
    seq_iterator = BasicIterator(batch_size=config.batch_size)
    seq_iterator.index_with(vocab)
    
    predictor = Predictor(model, seq_iterator, cuda_device=0 if USE_GPU else -1)
    test_preds = predictor.predict(test_ds) 
    
    label_types = [r_idx2label.get(i.fields['label'].label) for i in test_ds]
    predict_types = [r_idx2label.get(i) for i in np.argmax(test_preds, axis=-1)]    

In [None]:
    plot_comfusion_matrix(label_types, predict_types, output_path, "test")