In [None]:
# Clone the repository that ships the CoNLL04 files read by the data-loading cell below.
!git clone https://github.com/bekou/multihead_joint_entity_relation_extraction
# Python dependencies. Only transformers is pinned to a version here.
!pip install transformers==3.3.0
!pip install datasets
!pip install seqeval

Cloning into 'multihead_joint_entity_relation_extraction'...
remote: Enumerating objects: 160, done.[K
remote: Total 160 (delta 0), reused 0 (delta 0), pack-reused 160[K
Receiving objects: 100% (160/160), 171.31 MiB | 47.68 MiB/s, done.
Resolving deltas: 100% (76/76), done.
Collecting transformers==3.3.0
[?25l  Downloading https://files.pythonhosted.org/packages/3a/fc/18e56e5b1093052bacf6750442410423f3d9785d14ce4f54ab2ac6b112a6/transformers-3.3.0-py3-none-any.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 10.4MB/s 
[?25hCollecting sentencepiece!=0.1.92
[?25l  Downloading https://files.pythonhosted.org/packages/e5/2d/6d4ca4bef9a67070fa1cac508606328329152b1df10bdf31fb6e4e727894/sentencepiece-0.1.94-cp36-cp36m-manylinux2014_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 50.5MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.g

In [None]:
class RawSample(object):
  """Container for one annotated sentence.

  Attributes:
    doc_id: identifier taken from the sentence's header line.
    tokens: list of word strings.
    text: optional raw text of the sentence (may be None).
    ner_labels: list of entity tags, one per token.
    rel_labels: list of relation entries for the sentence.
  """
  def __init__(self, doc_id, tokens, text, ner_labels, rel_labels):
    self.doc_id = doc_id
    self.tokens = tokens
    # Previously accepted but never stored, so the value was lost.
    self.text = text
    self.ner_labels = ner_labels
    self.rel_labels = rel_labels

def raw_dataset(path):
  """Parse a tab-separated, CoNLL-style file into a list of RawSample objects.

  A line starting with '#' opens a new sample; its last space-separated field
  becomes the sample's doc_id. Every other non-blank line must have five
  tab-separated fields: position id, token, NER tag, relation list and head-id
  list (the two lists are written like "['A', 'B']" and "[3, 7]").

  A token contributes relation triples (position, head, relation) only when its
  first head id differs from its own position id.

  Args:
    path: path of the file to read.

  Returns:
    List of RawSample, in file order.

  Raises:
    ValueError: if a token line appears before any '#' header, or a line does
      not have exactly five tab-separated fields.
  """
  samples = []
  current = None
  with open(path, 'r', encoding='utf-8') as f:
    # Iterate lazily instead of materialising the whole file with readlines().
    for line_no, raw_line in enumerate(f, start=1):
      line = raw_line.strip()
      if not line:
        # Blank separator/trailing lines used to crash the 5-field unpack.
        continue

      if line.startswith('#'):
        current = RawSample(line.split(' ')[-1], [], None, [], [])
        samples.append(current)
        continue

      if current is None:
        raise ValueError(
            'line %d of %s: token line found before any "#" header' % (line_no, path))

      pos_id, token, ner, rel_field, head_field = line.split('\t')
      current.tokens.append(token)
      current.ner_labels.append(ner)

      head_ids = head_field.strip('[').strip(']').split(',')
      if pos_id != head_ids[0]:
        rel_names = rel_field.replace("'", '').strip("['").strip("']").split(',')
        for head, rel_name in zip(head_ids, rel_names):
          current.rel_labels.append((int(pos_id), int(head), rel_name.strip()))

  return samples

# Parse the train / dev / test files of CoNLL04, in that order.
train_raw, dev_raw, test_raw = (
  raw_dataset(split_path)
  for split_path in (
    '/content/multihead_joint_entity_relation_extraction/data/CoNLL04/train.txt',
    '/content/multihead_joint_entity_relation_extraction/data/CoNLL04/dev.txt',
    '/content/multihead_joint_entity_relation_extraction/data/CoNLL04/test.txt',
  )
)

In [None]:
import itertools
# TODO: the label inventories come from the training split only -- decide what
# should happen when a relation that was never seen in training shows up later.
ner_labels = {tag for sample in train_raw for tag in sample.ner_labels}
# Each relation entry ends with the relation name; keep only the distinct names.
rel_labels = {triple[-1] for sample in train_raw for triple in sample.rel_labels}
                                              
def label_lookup(unique_labels):
  """Build label->id and id->label dictionaries.

  Labels are de-duplicated and sorted before ids 1..N are assigned, so the
  mapping is identical on every run. Iterating a set of strings directly gives
  an order that can change between interpreter sessions, which would silently
  change the ids.

  Two reserved entries are always added: 'NEGATIVE' -> 0 and 'IGNORE' -> -100.

  Args:
    unique_labels: iterable of label strings (duplicates are tolerated).

  Returns:
    (label2idx, idx2label) tuple of dicts.
  """
  ordered_labels = sorted(set(unique_labels))
  label2idx = {label: idx for idx, label in enumerate(ordered_labels, start=1)}
  label2idx['NEGATIVE'], label2idx['IGNORE'] = 0, -100
  idx2label = {value: key for key, value in label2idx.items()}
  return label2idx, idx2label

# Forward (label -> id) and reverse (id -> label) lookups for entity tags and relation names.
ner_label2idx, ner_idx2label = label_lookup(ner_labels)
rel_label2idx, rel_idx2label = label_lookup(rel_labels)

In [None]:
import torch
import random
import numpy as np

def seed_all(seed):
    """Seed Python's `random`, NumPy and PyTorch (CUDA, then CPU) with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.cuda.manual_seed,
        torch.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Data loader

In [None]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
from transformers import RobertaTokenizerFast
from tokenizers import Encoding

In [None]:
# Try the tokenizer on the first training sentence.
example = train_raw[0]
# Pre-split (word-level) input requires add_prefix_space=True, and it has to be
# given when the tokenizer is built. Assigning `tokenizer.add_prefix_space`
# afterwards only flips the Python-side attribute.
# NOTE(review): ner_bpe_mapping relies on the first sub-token of every word
# having offset start 0 -- confirm that still holds with this setting.
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base', add_prefix_space=True)
encode = tokenizer(example.tokens, return_offsets_mapping= True, is_split_into_words= True)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




In [None]:
# Pull out the sample sentence's words, entity tags and sub-token offsets for inspection.
tokens, ners, offset_mapping = example.tokens, example.ner_labels, encode.offset_mapping

In [None]:
def ner_bpe_mapping(example, offset_mapping):
  """Align word-level NER labels with sub-token positions.

  A position counts as the first sub-token of a word when its offset pair is
  (0, end) with end != 0. Those positions receive the word's label id and the
  word's index; every other position is filled with -100.

  If fewer word starts are found than the example has words (for instance
  because the encoding was truncated), only the leading words are labelled
  instead of failing with a shape mismatch.

  Args:
    example: object with `tokens` and `ner_labels` lists of equal length.
    offset_mapping: sequence of (start, end) pairs, one per sub-token position.

  Returns:
    (ner_bpe_label, bpe_mapping): two int64 arrays, one entry per position.
    `ner_bpe_label` holds label ids, `bpe_mapping` holds word indices.

  Raises:
    ValueError: if more word starts are detected than the example has labels.
  """
  ner_label_ids = np.array([ner_label2idx[ner] for ner in example.ner_labels],
                           dtype=np.int64)
  # reshape keeps the column indexing below valid for an empty mapping
  offset_mapping = np.asarray(offset_mapping).reshape(-1, 2)

  # Ids are integers, so keep them in integer arrays rather than float ones.
  ner_bpe_label = np.full(len(offset_mapping), -100, dtype=np.int64)
  bpe_mapping = np.full(len(offset_mapping), -100, dtype=np.int64)

  is_word_start = (offset_mapping[:, 0] == 0) & (offset_mapping[:, 1] != 0)
  word_start_positions = np.flatnonzero(is_word_start)
  n_visible = len(word_start_positions)
  if n_visible > len(ner_label_ids):
    raise ValueError(
        'found %d word starts but the example has only %d NER labels'
        % (n_visible, len(ner_label_ids)))

  # Words cut off by truncation sit at the end, so the leading slice lines up.
  ner_bpe_label[word_start_positions] = ner_label_ids[:n_visible]
  bpe_mapping[word_start_positions] = np.arange(n_visible)

  return ner_bpe_label, bpe_mapping

def creat_rel_label(example, bpe_mapping):
  """Build the (positions x positions) relation-label matrix for one example.

  `example.rel_labels` holds (start_word, end_word, relation_name) entries whose
  indices are word positions, while the matrix is indexed by sub-token position.
  `bpe_mapping` (word index at each word-start position, -100 elsewhere) is used
  to translate word indices before writing, so relation cells line up with the
  NER labels. Relations with an endpoint that has no position in `bpe_mapping`
  (e.g. a truncated word) are skipped.

  NOTE(review): every cell without a relation stays at -100, so no pair is ever
  labelled NEGATIVE (0) -- confirm whether that is intended.

  Args:
    example: object with a `rel_labels` list as described above.
    bpe_mapping: 1-D array-like with one entry per sub-token position.

  Returns:
    int64 array of shape (len(bpe_mapping), len(bpe_mapping)).
  """
  bpe_mapping = np.asarray(bpe_mapping)
  _len = len(bpe_mapping)
  # np.int was removed in NumPy 1.24; use an explicit integer dtype.
  rel_label = np.full((_len, _len), -100, dtype=np.int64)

  # word index -> position of that word's first sub-token
  word2pos = {int(word_idx): pos
              for pos, word_idx in enumerate(bpe_mapping) if word_idx >= 0}

  for start_word, end_word, rel_name in example.rel_labels:
    if start_word in word2pos and end_word in word2pos:
      rel_label[word2pos[start_word], word2pos[end_word]] = rel_label2idx[rel_name]

  return rel_label

In [None]:
import torch
class ColDataset(object):
  """Indexable dataset that turns raw samples into model-ready tensors.

  Each item is tokenized on access and padded/truncated to `max_length`.
  """

  def __init__(self,tokenizer, raw_data, max_length = 169):
    self.max_length = max_length
    self.raw_data = raw_data
    self.tokenizer = tokenizer

  def __len__(self):
    return len(self.raw_data)

  def __getitem__(self, idx):
    """Return (input_ids, attention_mask, token_type_ids, ner_label, rel_label)."""
    sample = self.raw_data[idx]
    encoding = self.tokenizer(
        sample.tokens,
        return_offsets_mapping=True,
        is_split_into_words=True,
        truncation=True,
        padding='max_length',
        max_length=self.max_length,
    )

    # Word-level annotations are projected onto sub-token positions.
    ner_target, word_index_per_position = ner_bpe_mapping(sample, encoding.offset_mapping)
    rel_target = creat_rel_label(sample, word_index_per_position)

    input_ids = torch.LongTensor(encoding.input_ids)
    attention_mask = torch.FloatTensor(encoding.attention_mask)
    # All-zero segment ids, one per position.
    token_type_ids = torch.LongTensor(np.zeros(self.max_length))
    ner_tensor = torch.LongTensor(ner_target)
    rel_tensor = torch.LongTensor(rel_target)
    return (input_ids, attention_mask, token_type_ids, ner_tensor, rel_tensor)

In [None]:
# Wrap each raw split so that items are tokenized and aligned on access.
train_ds = ColDataset(tokenizer, train_raw)
dev_ds = ColDataset(tokenizer, dev_raw)
test_ds = ColDataset(tokenizer, test_raw)

# Only the training loader is shuffled. Evaluation loaders keep the dataset order
# so that predictions can be matched back to samples and runs are repeatable.
train_dl = DataLoader(train_ds, batch_size= 8, shuffle= True)
dev_dl = DataLoader(dev_ds, batch_size= 16, shuffle= False)
test_dl = DataLoader(test_ds, batch_size= 16, shuffle= False)

In [None]:
# Fetch one training batch and print the shape of every tensor in it.
batch = next(iter(train_dl))
for item in batch: print(item.shape)
# The label tensors get their own names: `ner_labels` / `rel_labels` already hold
# the label-name sets that the label_lookup cell reads, and overwriting them
# would break that cell when it is re-run.
input_ids, attention_mask, token_type_ids, batch_ner_labels, batch_rel_labels = batch

torch.Size([8, 169])
torch.Size([8, 169])
torch.Size([8, 169])
torch.Size([8, 169])
torch.Size([8, 169, 169])


# Model

In [None]:
import torch.nn as nn
from transformers import RobertaModel, RobertaConfig
# Pre-trained RoBERTa encoder.
# NOTE(review): the tokenizer above is loaded from 'roberta-base' while the
# encoder here is 'roberta-large' -- confirm that this pairing is intended.
transformer_model = RobertaModel.from_pretrained('roberta-large')

# RobertaConfig is already imported at the top of this cell, so the second
# import that used to sit here was redundant.
config = RobertaConfig.from_pretrained('roberta-large')
config.hidden_size

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=482.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1425941629.0, style=ProgressStyle(descr…




In [None]:
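# TODO: unfinished draft of the joint NER / relation model. It is kept commented
# out because it is not valid Python yet: `num_rel_class` has no default value and
# the last statement is incomplete. An nn.Module subclass must also call
# `super().__init__()` before sub-modules are assigned.
# class RelModel(nn.Module):
#   def __init__(self, ner_hidden_size = 256, num_ner_class = 10, rel_hidden_size = 256, num_rel_class = ):
#     self.transformer_model = transformer_model
#     self.ner_linear1 = nn.Linear(config.hidden_size, ner_hidden_size)
#     self.ner_linear2 = nn.Linear(ner_hidden_size, num_ner_class)

#     self.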
  
  



In [None]:
# Number of entries in the relation lookup; this count includes the reserved
# 'NEGATIVE' (0) and 'IGNORE' (-100) keys that label_lookup adds.
len(rel_label2idx)

7