In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader
from torchcrf import CRF
from deeppavlov.models.embedders.fasttext_embedder import FasttextEmbedder
import re
import json
import nltk
nltk.download("punkt")
from nltk import sent_tokenize, word_tokenize
from nltk.tokenize.util import align_tokens
from glob import glob
from functools import partial
from sklearn.model_selection import train_test_split
from tqdm import tqdm_notebook as tqdm
from conlleval import evaluate as prec_rec_f
from brat_format import read_file, BratDoc

[nltk_data] Downloading package punkt to /home/dmitry/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [29]:
def span_sentences(text, shift=0):
    """
    Extracts sentences and their spans from text.

    Parameters
    text : str
        Text to extract sentences and spans from.
    shift : int
        Initial position from which to start counting span.

    Returns
    sents : List[str]
        Sentences extracted from text.
    spans : List[Tuple[int, int]]
        Extracted sentences position in text.
    """
    
    sents = sent_tokenize(text, language="russian")
    spans = align_tokens(sents, text)
    spans = [(start + shift, end + shift) for start, end in spans]
    
    return sents, spans


def span_tokens(text, shift=0):
    """
    Extracts tokens and their spans from text.

    Parameters
    text : str
        Text to extract tokens and spans from.
    shift : int
        Initial position from which to start counting span.

    Returns
    tokens : List[str]
        Tokens extracted from text.
    spans : List[Tuple[int, int]]
        Extracted tokens position in text.
    """
    
    tokens, spans = [], []

    for tok in re.finditer(r"([^\W_]+|\S)", text):
        tokens.append(tok.group(1))
        spans.append((shift + tok.start(1), 
                      shift + tok.end(1)))
    
    return tokens, spans


def to_conll(brat_ners, spans):
    """
    Converts named entities from brat to conll format. In conll format every 
    token has a tag:
    B-named_entity_type - for the first token in named entity,
    I-named_entity_type - for a token of named entity that is not first,
    O - for a token out of named entity.

    Parameters
    brat_ners : List[Dict]
        Named entities in brat format.
    spans : List[Tuple[int, int]]
        Position of tokens in reference text.

    Returns
    conll_ners : List[str]
        Conll tags of the tokens corresponding to spans.
    """
    
    conll_ners = []

    for token_start, token_end in spans:
        
        for ner in brat_ners:
            
            if (ner["start"] <= token_start) and (ner["end"] >= token_end):
                prefix = "I" if (ner["start"] < token_start) else "B"
                conll_ners.append(prefix + "-" + ner["ner_type"])
                break
        
        else:
            conll_ners.append("O")  
    
    return conll_ners


def to_brat(conll_ners, spans, ner_id=1):
    """
    Converts named entities from conll to brat format. In brat format every 
    named entity is represented with its id, type, and position in reference 
    text.

    Parameters
    conll_ners : List[str]
        Conll tags of the tokens corresponding to spans.
    spans : List[tuple[int]]
        Position of tokens in reference text.
    ner_id : int
        The initial id from which to start counting ner_ids

    Returns
    brat_ners : List[Dict]
        Named entities in brat format.
    """

    brat_ners = []
    prev = "O"

    for tag, (token_start, token_end)  in zip(conll_ners, spans):
        splitted_tag = tag.split("-")
        
        if len(splitted_tag) > 1:
            prefix, ner_type = splitted_tag
            
            if prefix == "I":
                
                if prev != "O":
                    brat_ners[-1]["end"] = token_end
                    prev = "I"
                    continue
            
            brat_ners.append({"ner_id": ner_id, 
                              "ner_type": ner_type, 
                              "start": token_start, 
                              "end": token_end})
            prev = "B"
            ner_id += 1
        
        else:
            prev = "O"

    return brat_ners


def extract_data(files):
    """
    Given text sequence as tokens, predicts corresponding conll tags.

    Parameters
    files : List[str]
        Paths to .ann files to extract data from.

    Returns
    tokens : List[List[str]]
        Tokenized text sequences.
    tags : List[List[str]]
        Conll tags corresponding to token sequences.
    """

    tokens, tags = [], []

    for file_path in tqdm(files):
        brat_doc = read_file(file_path)
        doc_ners = [{"id": i, 
                    "ner_type": brat_doc.ners[idx][0], 
                    "start": brat_doc.ners[idx][1], 
                    "end": brat_doc.ners[idx][2]} 
                    for i, idx in brat_doc.ner_id_2_idx.items()]
        
        for line in re.finditer(r"[^\n]+(\n+|$)", brat_doc.txt_data):
            sents, sent_spans = span_sentences(line.group(0), shift=line.start())
            
            for sent, (sent_start, _) in zip(sents, sent_spans):
                toks, spans = span_tokens(sent, shift=sent_start)
                tokens.append(toks)
                tags.append(to_conll(doc_ners, spans))

    return tokens, tags

In [34]:
files = glob("data/train/*.ann")
tokens, tags = extract_data(files)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for file_path in tqdm(files):


  0%|          | 0/188 [00:00<?, ?it/s]

In [35]:


files = glob("data/train/*.ann")
tokens, tags = extract_data(files)
     


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for file_path in tqdm(files):


  0%|          | 0/188 [00:00<?, ?it/s]

In [36]:
train_tokens, val_tokens, train_tags, val_tags = train_test_split(tokens, tags, 
                                                                  test_size=0.1)
len(train_tokens), len(val_tokens)

(19026, 2115)

In [37]:


class NER_Dataset(Dataset):
    def __init__(self, tag2id, seqs, seq_tags):
        self.tag2id = tag2id
        self.seqs = [[token.lower() for token in seq] for seq in seqs]
        self.seq_tags = [[self.tag2id[tag] for tag in tags] for tags in seq_tags]

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        return self.seqs[idx], self.seq_tags[idx]
     


In [38]:


# Conll tags encoding
tags = list({tag for sent in train_tags for tag in sent})
tag2id = {tag: i for i, tag in enumerate(tags)}
id2tag = {i: tag for i, tag in enumerate(tags)}

with open("tags.json", "w") as f:
    json.dump(tags, f)

tags
     


['I-ACT',
 'I-SOC',
 'B-SOC',
 'B-ACT',
 'B-BIN',
 'B-MET',
 'I-INST',
 'I-QUA',
 'I-CMP',
 'I-ECO',
 'B-QUA',
 'O',
 'I-MET',
 'I-BIN',
 'B-CMP',
 'B-ECO',
 'B-INST']

In [39]:


train_ds = NER_Dataset(tag2id, train_tokens, train_tags)
val_ds = NER_Dataset(tag2id, val_tokens, val_tags)
     


In [40]:


val_ds[:2]
     


([['органами',
   'исполнительной',
   'власти',
   'алтайского',
   'края',
   'совместно',
   'с',
   'алтайкрайстатом',
   'будут',
   'ежегодно',
   'определяться',
   'тематики',
   'аналитических',
   'материалов',
   'по',
   'актуальным',
   'направлениям',
   'социально',
   '-',
   'экономического',
   'развития',
   'алтайского',
   'края',
   '.'],
  ['3', ')']],
 [[16,
   6,
   6,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   11,
   10,
   11,
   15,
   9,
   9,
   14,
   11,
   11,
   11],
  [11, 11]])

In [41]:


class BiLSTM_CRF(nn.Module):
    def __init__(self, embedding_size, hidden_size, feature_dim, num_classes, 
                 dropout):
        super().__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.feature_dim = feature_dim
        self.num_classes = num_classes
        self.dropout = dropout

        self.lstm = nn.LSTM(embedding_size, hidden_size, 2, bidirectional=True, 
                            batch_first=True)
        self.drop = nn.Dropout(dropout)
        self.fc_0 = nn.Linear(2 * hidden_size, feature_dim)
        self.Q = nn.Linear(feature_dim, feature_dim)
        self.K = nn.Linear(feature_dim, feature_dim)
        self.V = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)
        self.fc_1 = nn.Linear(feature_dim, num_classes)
        self.crf = CRF(num_classes, batch_first=True)


    def forward(self, x, lengths):
        # LSTM
        x_packed = pack_padded_sequence(x, lengths, batch_first=True)
        seq_out_packed, _ = self.lstm(x_packed)
        seq_out, _ = pad_packed_sequence(seq_out_packed, batch_first=True)
        seq_out = self.drop(seq_out)
        seq_out = self.fc_0(F.relu(seq_out))

        # Attention
        Q, K, V = self.Q(seq_out), self.K(seq_out), self.V(seq_out)
        attn = torch.bmm(Q, K.transpose(1, 2))
        attn /= torch.sqrt(torch.tensor(self.feature_dim, dtype=torch.float))
        attn = F.softmax(attn, dim=-1)
        out = torch.bmm(attn, V)
        out = self.layer_norm(out)

        scores = self.fc_1(out)

        return scores
     


In [42]:


device = "cuda" if torch.cuda.is_available() else "cpu"
elmo_embedder = FasttextEmbedder("http://files.deeppavlov.ai/deeppavlov_data/elmo_ru-news_wmt11-16_1.5M_steps.tar.gz", 
                             elmo_output_names=["elmo"])
     


NameError: name 'ELMoEmbedder' is not defined