In [None]:
import math
import pickle
import pandas as pd
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from transformers import BertTokenizerFast, BertConfig, BertModel
from transformers import get_linear_schedule_with_warmup
from torch import nn, optim, cuda, Tensor
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader

from natasha import NewsSyntaxParser, NewsEmbedding, Doc, Segmenter

from razdel import sentenize, tokenize
import re
import copy
import numpy as np
from tqdm import tqdm
import logging

In [None]:
import warnings

# Suppress every warning of category UserWarning for the rest of the session.
warnings.simplefilter(action="ignore", category=UserWarning)

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [None]:
# Named logger used by the evaluation code; records INFO and above.
logger = logging.getLogger(name='Testing Number Version')
logger.setLevel(level=logging.INFO)

In [None]:
# Log records are written as the bare message text (no timestamp / level prefix).
format_logger = logging.Formatter(
    '%(message)s')

# NOTE(review): absolute path; FileHandler does not create missing directories,
# so /home/logs has to exist before this cell is run.
filehandler_logger = logging.FileHandler('/home/logs/testing_number_copy_35.log')

filehandler_logger.setFormatter(format_logger)

# logging.getLogger returns the same logger object for the same name, so re-running
# this cell used to stack one more handler on it and every message was written
# several times. Drop handlers that already point at this file before attaching.
for _stale_handler in [handler for handler in logger.handlers
                       if isinstance(handler, logging.FileHandler)
                       and handler.baseFilename == filehandler_logger.baseFilename]:
    logger.removeHandler(_stale_handler)
    _stale_handler.close()

logger.addHandler(filehandler_logger)

# Dataset utils

In [None]:
# only evaluation and training
class dataset(Dataset):
    '''
    Torch dataset that encodes one sentence per item.

    Every item is a dict with the keys:
    'input_ids' - one vocabulary id per word token, zero-padded to max_len
    'mask'      - attention mask (1 for a real token, 0 for padding)
    'adj'       - syntactic adjacency matrix of shape (max_len, max_len) (for AGGCN)
    'address'   - [TextID, SentID] of the source row
    '''
    def __init__(self, texts, span_labels=None, max_len=None):
        # texts is indexed by the 'Contents', 'TextID' and 'SentID' keys in __getitem__.
        self.data = texts
        self.len = len(texts)
        self.span_labels = span_labels
        # max_len is used as a list multiplier and as the matrix size in __getitem__,
        # so it has to be an int by the time an item is requested.
        self.max_len = max_len
        self.bert_tokenizer = BertTokenizerFast.from_pretrained('DeepPavlov/rubert-base-cased')
        self.syntax_parser = NewsSyntaxParser(NewsEmbedding())

    def __getitem__(self, index):
        limit = self.max_len

        # Word-level tokens of the sentence, truncated to max_len.
        words = [token.text for token in tokenize(self.data['Contents'][index])][:limit]

        input_ids = [0] * limit
        attention_mask = [0] * limit
        for position, word in enumerate(words):
            # Only the first id returned by the tokenizer is kept for each word.
            pieces = self.bert_tokenizer.encode(word, add_special_tokens=False)
            input_ids[position] = pieces[0]
            attention_mask[position] = 1

        parsed_tokens = self.syntax_parser(words).tokens
        adj_matrix = np.zeros((limit, limit))
        for token in parsed_tokens:
            dependent = int(token.id) - 1
            head = int(token.head_id) - 1
            # Ids are shifted to 0-based; an id that becomes negative is not an edge.
            if dependent >= 0 and head >= 0:
                adj_matrix[dependent][head] = 1
                adj_matrix[head][dependent] = 1

        # Keep the path to the gold data so that validation can look up the matching
        # part of the dataframe; this way the gold spans do not have to be stored
        # in a tensor of one fixed size.
        address_true_spans = [self.data['TextID'][index],
                              self.data['SentID'][index]]

        return {'input_ids': torch.as_tensor(input_ids),
                'mask': torch.as_tensor(attention_mask),
                'adj': torch.as_tensor(adj_matrix),
                'address': address_true_spans}

    def __len__(self):
        return self.len

In [None]:
def get_ind_sequence(dataframe):

    '''
    Returns, for every row, the token sequence as [start, stop] offsets of each token.

    Reads the 'TextID', 'SentID' and 'Contents' columns of dataframe and returns a new
    DataFrame with the columns 'TextID', 'SentID' and 'Indices', where 'Indices' holds
    one [start, stop] pair per token produced by tokenize for that row.
    '''

    text_ids, sent_ids, token_offsets = [], [], []
    rows = zip(dataframe['TextID'], dataframe['SentID'], dataframe['Contents'])
    for textid, sentid, content in rows:
        text_ids.append(textid)
        sent_ids.append(sentid)
        token_offsets.append([[tok.start, tok.stop] for tok in tokenize(content)])

    return pd.DataFrame({'TextID': text_ids, 'SentID': sent_ids, 'Indices': token_offsets})

# Loss Function

In [None]:
def convert_label_to_similarity(normed_feature: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Split pairwise similarities into same-label and different-label scores.

    The similarity of two rows is their dot product. Only pairs above the main
    diagonal are used, so each unordered pair is counted once and self-pairs are
    skipped.

    Returns a tuple (sp, sn): 1-D tensors with the similarities of the pairs whose
    labels are equal and of the pairs whose labels differ, in row-major order.
    """
    same_label = label.unsqueeze(1) == label.unsqueeze(0)
    pairwise_similarity = normed_feature @ normed_feature.transpose(1, 0)

    positive_pairs = same_label.triu(diagonal=1)
    negative_pairs = same_label.logical_not().triu(diagonal=1)

    # Boolean indexing with a 2-D mask walks the matrix in row-major order,
    # i.e. the same order as flattening both tensors first.
    return pairwise_similarity[positive_pairs], pairwise_similarity[negative_pairs]

In [None]:
class CircleLoss(nn.Module):
    """
    Circle loss over two sets of similarity scores.

    m     : margin; it sets the positive target (1 - m) and the negative target (m)
    gamma : scale factor applied to both groups of logits
    """

    def __init__(self, m: float, gamma: float) -> None:
        super().__init__()
        self.m = m
        self.gamma = gamma
        self.soft_plus = nn.Softplus()

    def forward(self, sp: Tensor, sn: Tensor) -> Tensor:
        """
        sp : 1-D tensor of similarities that should be high
        sn : 1-D tensor of similarities that should be low
        Returns the scalar loss softplus(logsumexp(negative logits) + logsumexp(positive logits)).
        """
        margin, scale = self.m, self.gamma

        # The weighting factors are computed from detached scores, so gradients
        # flow only through the (score - target) terms below.
        weight_p = (- sp.detach() + 1 + margin).clamp_min(0.)
        weight_n = (sn.detach() + margin).clamp_min(0.)

        target_p = 1 - margin
        target_n = margin

        logit_p = - weight_p * (sp - target_p) * scale
        logit_n = weight_n * (sn - target_n) * scale

        combined = torch.logsumexp(logit_n, dim=0) + torch.logsumexp(logit_p, dim=0)
        return self.soft_plus(combined)

# DataEvalUtils

In [None]:
def get_batch_labels(batch_addresses: List, spans_dataset: pd.DataFrame, indices_dataset: pd.DataFrame, span_indices: Tensor,
                     labels2ids):
    """
    Build the gold label of every candidate span for every sentence of a batch.

    # Parameters
    batch_addresses : two parallel sequences; element 0 holds the TextID of each
        batch item, element 1 is a tensor with the matching SentID values.
    spans_dataset : frame with the columns 'TextID', 'SentID' and 'Spans'; each
        'Spans' cell is a string of the form "(start, stop, 'LABEL')".
    indices_dataset : frame with the columns 'TextID', 'SentID' and 'Indices';
        each 'Indices' cell is a list of [start, stop] pairs, one per token.
    span_indices : tensor whose rows are the candidate spans as
        [first_token, last_token].
    labels2ids : mapping from label name to integer id.

    # Returns
    A tuple (labels, types), both of shape (batch_size, number of candidate spans):
    labels is int64 and holds -100 for candidates that were not reached, 0 for
    candidates without a gold entity and the label id otherwise; types holds 1 where
    a gold entity was assigned and 0 elsewhere.
    """
    batch_size = len(batch_addresses[0])
    span_number = span_indices.shape[0]
    span_indices = span_indices.tolist()

    # Position of the first occurrence of every candidate span. This gives the same
    # answer as list.index() without rescanning the candidate list for each entity.
    candidate_position = {}
    for position, candidate in enumerate(span_indices):
        candidate_position.setdefault(tuple(candidate), position)

    batch_labels = np.full((batch_size, span_number), -100.0)
    batch_type = np.zeros((batch_size, span_number))

    sentence_ids = batch_addresses[1].detach().tolist()
    for batch_idx, (textid, sentid) in enumerate(zip(batch_addresses[0], sentence_ids)):

        # One combined mask instead of chained boolean indexing, which forces pandas
        # to reindex the second mask and emits a warning on every call.
        span_rows = spans_dataset[(spans_dataset['TextID'] == textid)
                                  & (spans_dataset['SentID'] == sentid)]
        readable_spans = []
        for ent in span_rows['Spans']:
            fields = ent[1:-1].split(', ')
            # fields[2] is the quoted label; strip the surrounding quote characters.
            readable_spans.append([int(fields[0]), int(fields[1]), labels2ids[fields[2][1:-1]]])

        index_rows = indices_dataset[(indices_dataset['TextID'] == textid)
                                     & (indices_dataset['SentID'] == sentid)]
        true_indices = index_rows['Indices'].iloc[0]

        # NOTE(review): max_ind is the stop offset of the last token, while the
        # candidate spans below hold token positions - confirm that comparing the
        # two is intended.
        max_ind = true_indices[-1][-1]
        for idx, spn in enumerate(span_indices):
            if spn[0] <= max_ind or spn[1] <= max_ind:
                batch_labels[batch_idx][idx] = 0
            else:
                break

        for span in readable_spans:
            covered = range(span[0], span[1] + 1)
            first_token = last_token = None
            for idx, spn in enumerate(true_indices):
                # A token belongs to the entity when both of its offsets fall
                # inside the entity's [start, stop] interval.
                if spn[0] in covered and spn[1] in covered:
                    if first_token is None:
                        first_token = idx
                    last_token = idx
            if first_token is None:
                continue

            # An entity whose token span is not among the candidates is skipped;
            # any other failure is allowed to surface instead of being swallowed.
            place = candidate_position.get((first_token, last_token))
            if place is not None:
                batch_labels[batch_idx][place] = span[2]
                batch_type[batch_idx][place] = 1

    return torch.from_numpy(batch_labels).type(torch.int64), torch.from_numpy(batch_type)

In [None]:
class Evaluator:
    """
    Computes mention-level and per-type precision / recall / F1 and logs them.

    # Parameters
    ids2labels : mapping from integer type id to label name. The entry for id 0
        is not an entity type and is left out; the remaining ids are expected to
        be 1..N, because types are looked up as ids2labels[i + 1].
    excluded_ids : type ids that are neither logged nor included in Macro F1
        (default: 24, 28 and 29).
    """

    def __init__(self, ids2labels, excluded_ids=(24, 28, 29)):
        # Work on a copy so that the caller's mapping keeps its id-0 entry.
        labels = dict(ids2labels)
        labels.pop(0, None)
        self.ids2labels = labels
        self.num_types = len(self.ids2labels)
        self.excluded_ids = frozenset(excluded_ids)

    def evaluate(self, y_true, y_pred):
        """
        y_true, y_pred : batches / lists of per-span label id sequences.
        A mention is the pair (span position, label id); label 0 means "no entity".
        Returns Macro F1 as a fraction in [0, 1], averaged over the types that
        were actually scored.
        """
        # NOTE(review): only label 0 is skipped below; any other value in y_true
        # (for example an ignore index) is counted as an entity - confirm that
        # callers filter such positions out beforehand.

        tp, fn, fp = 0, 0, 0
        sub_tp, sub_fn, sub_fp = [0] * self.num_types, [0] * self.num_types, [0] * self.num_types

        for gold_example, pred_example in zip(y_true, y_pred):
            gold_ners = set([(idx, int(ent)) for idx, ent in enumerate(gold_example) if ent != 0])
            pred_ners = set([(idx, int(ent)) for idx, ent in enumerate(pred_example) if ent != 0])

            tp += len(gold_ners & pred_ners)
            fn += len(gold_ners - pred_ners)
            fp += len(pred_ners - gold_ners)
            for i in range(self.num_types):
                sub_gm = set((idx, ent) for idx, ent in gold_ners if ent == i+1)
                sub_pm = set((idx, ent) for idx, ent in pred_ners if ent == i+1)
                sub_tp[i] += len(sub_gm & sub_pm)
                sub_fn[i] += len(sub_gm - sub_pm)
                sub_fp[i] += len(sub_pm - sub_gm)

        m_r = 0 if tp == 0 else float(tp) / (tp+fn)
        m_p = 0 if tp == 0 else float(tp) / (tp+fp)
        m_f1 = 0 if m_p == 0 else 2.0*m_r*m_p / (m_r+m_p)
        logger.info("Mention F1: {:.5f}%".format(m_f1 * 100))
        logger.info("Mention Recall: {:.5f}%".format(m_r * 100))
        logger.info("Mention Precision: {:.5f}%".format(m_p * 100))
        logger.info("****************SUB NER TYPES********************")
        f1_scores_list = []
        for i in range(self.num_types):
            if i+1 in self.excluded_ids:
                continue
            sub_r = 0 if sub_tp[i] == 0 else float(sub_tp[i]) / (sub_tp[i] + sub_fn[i])
            sub_p = 0 if sub_tp[i] == 0 else float(sub_tp[i]) / (sub_tp[i] + sub_fp[i])
            sub_f1 = 0 if sub_p == 0 else 2.0 * sub_r * sub_p / (sub_r + sub_p)
            f1_scores_list.append(sub_f1)
            logger.info("{} F1: {:.5f}%".format(self.ids2labels[i+1], sub_f1 * 100))
            logger.info("{} Recall: {:.5f}%".format(self.ids2labels[i+1], sub_r * 100))
            logger.info("{} Precision: {:.5f}%".format(self.ids2labels[i+1], sub_p * 100))
            logger.info(f'{sub_tp[i]}, {sub_fn[i]}, {sub_fp[i]}')
        summary_dict = {}
        summary_dict["Mention F1"] = m_f1
        summary_dict["Mention Recall"] = m_r
        summary_dict["Mention Precision"] = m_p

        # Average over the types that were scored. Dividing by a fixed
        # "num_types - 3" is only right when all three excluded ids exist, and it
        # divides by zero or a negative number for small label sets.
        if f1_scores_list:
            summary_dict["Macro F1"] = sum(f1_scores_list) / float(len(f1_scores_list))
        else:
            summary_dict["Macro F1"] = 0.0
        logger.info("Macro F1: {:.5f}%".format(summary_dict["Macro F1"] * 100))
        return summary_dict["Macro F1"]

# SpanExtraction

### Utils

In [None]:
class ConfigurationError(Exception):
    """
    The exception raised by any AllenNLP object when it's misconfigured
    (e.g. missing properties, invalid properties, unknown properties).
    """

    def __init__(self, message: str):
        # Hand the message to Exception as well: it ends up in `args`, which is
        # what copy / pickle use to rebuild the exception. With empty `args` the
        # rebuild calls __init__ without a message and fails.
        super().__init__(message)
        self.message = message

    def __str__(self):
        return self.message


In [None]:
def get_range_vector(size: int, device: int) -> Tensor:
    """
    Returns a range vector of the requested size, starting at 0.

    device is a CUDA device index; any value of -1 or lower selects the CPU.
    """
    if device <= -1:
        return torch.arange(0, size, dtype=torch.long)
    # On a CUDA device the range is built as a cumulative sum over a vector of ones.
    ones = cuda.LongTensor(size, device=device).fill_(1)
    return ones.cumsum(0) - 1

In [None]:
def get_device_of(tensor: Tensor) -> int:
    """
    Returns the CUDA device index of the tensor, or -1 when the tensor is not on CUDA.
    """
    return tensor.get_device() if tensor.is_cuda else -1

In [None]:
def bucket_values(distances: torch.Tensor,
                  num_identity_buckets: int=4,
                  num_total_buckets: int=10) -> torch.Tensor:
    """
    Places the values into `num_total_buckets` buckets, the first of which hold a
    single value each instead of a range.
    With the defaults the buckets are:
    [0, 1, 2, 3, 4, 5-7, 8-15, 16-31, 32-63, 64+].
    Used to encode span lengths.
    # Parameters
    distances : `torch.Tensor`, required.
        A Tensor of any size, to be bucketed.
    num_identity_buckets: `int`, optional (default = `4`).
        The number of identity buckets (those only holding a single value).
    num_total_buckets : `int`, (default = `10`)
        The total number of buckets to bucket values into.
    # Returns
    `torch.Tensor`
        A tensor of the same shape as the input, containing the indices of the buckets
        the values were placed in.
    """
    # 1 where the value keeps its own bucket, 0 where it goes to a log-scale bucket.
    identity_mask = (distances <= num_identity_buckets).long()
    range_mask = 1 - identity_mask

    # floor(log2(value)), shifted by (num_identity_buckets - 1) so that the
    # log-scale buckets continue right after the single-value buckets.
    log2_floor = torch.floor(torch.log(distances.float()) / math.log(2)).long()
    logspace_index = log2_floor + (num_identity_buckets - 1)

    combined_index = identity_mask * distances + range_mask * logspace_index
    # Everything past the last bucket is folded into the last bucket.
    return torch.clamp(combined_index, 0, num_total_buckets - 1)

In [None]:
def flatten_and_batch_shift_indices(indices: torch.Tensor,
                                    sequence_length: int) -> torch.Tensor:
    """
    Helper for batched_index_select.
    Takes `indices` of shape (batch_size, d_1, ..., d_n) that point into the sequence
    dimension of a (batch_size, sequence_length, embedding_size) target and turns them
    into indices into that target flattened to (batch_size * sequence_length, embedding_size).
    ```python
        indices = torch.ones([2,3], dtype=torch.long)
        # Sequence length of the target tensor.
        sequence_length = 10
        shifted_indices = flatten_and_batch_shift_indices(indices, sequence_length)
        # Indices into the second element in the batch are correctly shifted
        # to take into account that the target tensor will be flattened before
        # the indices are applied.
        assert shifted_indices == [1, 1, 1, 11, 11, 11]
    ```
    # Parameters
    indices : `torch.LongTensor`, required.
    sequence_length : `int`, required.
        The length of the sequence the indices index into.
        This must be the second dimension of the tensor.
    # Returns
    offset_indices : `torch.LongTensor`
        1-D tensor of length batch_size * d_1 * ... * d_n.
    # Raises
    ConfigurationError
        If any index is negative or not smaller than `sequence_length`.
    """
    if torch.max(indices) >= sequence_length or torch.min(indices) < 0:
        raise ConfigurationError(
            f"All elements in indices should be in range (0, {sequence_length - 1})"
        )

    # Shape: (batch_size) - the offset of every batch element in the flattened target.
    batch_offsets = get_range_vector(indices.size(0), get_device_of(indices)) * sequence_length

    # Shape: (batch_size, 1, ..., 1) so that the offsets broadcast over d_1 ... d_n.
    trailing_ones = [1] * (len(indices.shape) - 1)
    batch_offsets = batch_offsets.reshape(-1, *trailing_ones)

    # Shape: (batch_size * d_1 * ... * d_n)
    return (indices + batch_offsets).reshape(-1)


In [None]:
def batched_index_select(
        target: torch.Tensor,
        indices: torch.LongTensor,
        flattened_indices: Optional[torch.LongTensor]=None, ) -> torch.Tensor:
    """
    Takes `indices` of shape (batch_size, d_1, ..., d_n) that point into the sequence
    dimension (dimension 2) of `target`, whose shape is
    (batch_size, sequence_length, embedding_size).
    Returns the values of `target` selected by those indices, with shape
    (batch_size, d_1, ..., d_n, embedding_size).

    # Parameters
    target : `torch.Tensor`, required.
        A 3 dimensional tensor of shape (batch_size, sequence_length, embedding_size).
        This is the tensor to be indexed.
    indices : `torch.LongTensor`
        A tensor of shape (batch_size, ...), where each element is an index into the
        `sequence_length` dimension of the `target` tensor.
    flattened_indices : `Optional[torch.Tensor]`, optional (default = `None`)
        An optional tensor representing the result of calling `flatten_and_batch_shift_indices`
        on `indices`. This is helpful in the case that the indices can be flattened once and
        cached for many batch lookups.
    # Returns
    selected_targets : `torch.Tensor`
        A tensor with shape [indices.shape, target.size(-1)] representing the embedded indices
        extracted from the batch flattened target tensor.
    """
    embedding_size = target.size(-1)

    if flattened_indices is None:
        # Shape: (batch_size * d_1 * ... * d_n)
        flattened_indices = flatten_and_batch_shift_indices(indices,
                                                            target.size(1))

    # Shape: (batch_size * sequence_length, embedding_size)
    target_rows = target.reshape(-1, embedding_size)

    # Shape: (batch_size * d_1 * ... * d_n, embedding_size)
    picked_rows = target_rows.index_select(0, flattened_indices)

    # Shape: (batch_size, d_1, ..., d_n, embedding_size)
    return picked_rows.reshape(*indices.shape, embedding_size)

In [None]:
def get_lengths_from_binary_sequence_mask(
        mask: torch.BoolTensor) -> torch.LongTensor:
    """
    Computes the length of every sequence in the batch from a binary mask.
    # Parameters
    mask : `torch.BoolTensor`, required.
        A 2D binary mask of shape (batch_size, sequence_length) to
        calculate the per-batch sequence lengths from.
    # Returns
    `torch.LongTensor`
        A torch.LongTensor of shape (batch_size,) representing the lengths
        of the sequences in the batch.
    """
    # The length is the number of set entries along the last dimension.
    return torch.sum(mask, dim=-1)

In [None]:
def batched_span_select(target: torch.Tensor,
                        spans: torch.LongTensor) -> torch.Tensor:
    """
    Takes spans of shape (batch_size, num_spans, 2) that point into the sequence
    dimension (dimension 2) of `target`, whose shape is
    (batch_size, sequence_length, embedding_size), and returns the embeddings of
    every position covered by each span, together with a mask of the valid positions.
    # Parameters
    target : `torch.Tensor`, required.
        A 3 dimensional tensor of shape (batch_size, sequence_length, embedding_size).
        This is the tensor to be indexed.
    spans : `torch.LongTensor`
        A 3 dimensional tensor of shape (batch_size, num_spans, 2) representing start and end
        indices (both inclusive) into the `sequence_length` dimension of the `target` tensor.
    # Returns
    span_embeddings : `torch.Tensor`
        A tensor with shape (batch_size, num_spans, max_batch_span_width, embedding_size)
        representing the embedded spans extracted from the batch flattened target tensor.
    span_mask: `torch.BoolTensor`
        A tensor with shape (batch_size, num_spans, max_batch_span_width) representing the mask on
        the returned span embeddings.
    """
    # Both of shape (batch_size, num_spans, 1).
    span_starts, span_ends = torch.split(spans, 1, dim=-1)

    # Shape: (batch_size, num_spans, 1)
    # The ends are inclusive, so these widths are one less than the number of
    # positions in the span.
    span_widths = span_ends - span_starts

    # The widest span in the batch decides how many positions are gathered for
    # every span; positions past a span's own end are masked out below.
    max_batch_span_width = span_widths.max().item() + 1

    # Shape: (1, 1, max_batch_span_width)
    position_offsets = get_range_vector(
        max_batch_span_width, get_device_of(target)).reshape(1, 1, -1)

    # Shape: (batch_size, num_spans, max_batch_span_width)
    # Absolute sequence positions of every slot of every span.
    raw_span_indices = span_starts + position_offsets

    # A slot is valid when it lies inside its span (<= because the ends are
    # inclusive) and inside the sequence. The lower bound check is an addition
    # to the original implementation.
    inside_span = position_offsets <= span_widths
    inside_sequence = (raw_span_indices < target.size(1)) & (0 <= raw_span_indices)
    span_mask = inside_span & inside_sequence

    # Invalid slots are pointed at position 0 so that the lookup stays in range.
    span_indices = raw_span_indices.masked_fill(~span_mask, 0)

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = batched_index_select(target, span_indices)

    return span_embeddings, span_mask

In [None]:
def tiny_value_of_dtype(dtype: torch.dtype):
    """
    Returns a moderately tiny value for the given data type, used to avoid
    numerical problems such as division by zero.
    Only floating point types are supported.
    """
    if not dtype.is_floating_point:
        raise TypeError("Only supports floating point dtypes.")
    if dtype in (torch.float, torch.double):
        return 1e-13
    if dtype == torch.half:
        return 1e-4
    raise TypeError("Does not support dtype " + str(dtype))

In [None]:
def info_value_of_dtype(dtype: torch.dtype):
    """
    Returns the `finfo` or `iinfo` object of a given PyTorch data type. Does not allow torch.bool.
    """
    if dtype == torch.bool:
        raise TypeError("Does not support torch.bool")
    describe = torch.finfo if dtype.is_floating_point else torch.iinfo
    return describe(dtype)

def min_value_of_dtype(dtype: torch.dtype):
    """
    Returns the minimum value of a given PyTorch data type. Does not allow torch.bool.
    """
    limits = info_value_of_dtype(dtype)
    return limits.min


def max_value_of_dtype(dtype: torch.dtype):
    """
    Returns the maximum value of a given PyTorch data type. Does not allow torch.bool.
    """
    limits = info_value_of_dtype(dtype)
    return limits.max

In [None]:
def masked_max(
        vector: torch.Tensor,
        mask: torch.BoolTensor,
        dim: int,
        keepdim: bool=False, ) -> torch.Tensor:
    """
    Max pooling that ignores masked-out positions.
    # Parameters
    vector : `torch.Tensor`
        The vector to calculate max, assume unmasked parts are already zeros
    mask : `torch.BoolTensor`
        The mask of the vector. It must be broadcastable with vector.
    dim : `int`
        The dimension to calculate max
    keepdim : `bool`
        Whether to keep dimension
    # Returns
    `torch.Tensor`
        A `torch.Tensor` of including the maximum values.
    """
    # Masked-out positions get the smallest value of the dtype, so they only win
    # the max when every position along `dim` is masked out.
    lowest = min_value_of_dtype(vector.dtype)
    candidates = vector.masked_fill(~mask, lowest)
    return torch.max(candidates, dim=dim, keepdim=keepdim)[0]

In [None]:
def weighted_sum(matrix: torch.Tensor,
                 attention: torch.Tensor) -> torch.Tensor:
    """
    Takes a matrix of vectors and a set of weights over the rows of that matrix
    (an attention vector) and returns the weighted sum of the rows.
    """
    # Two common layouts map directly onto a batched matrix product.
    if matrix.ndim == 3:
        if attention.ndim == 2:
            return torch.bmm(attention.unsqueeze(1), matrix).squeeze(1)
        if attention.ndim == 3:
            return torch.bmm(attention, matrix)

    # General case: when the attention has extra leading dimensions, insert matching
    # dimensions into the matrix right after the batch dimension and expand them.
    if matrix.ndim - 1 < attention.ndim:
        extra_dims = attention.ndim - matrix.ndim + 1
        expanded_size = list(matrix.shape)
        for axis in range(1, extra_dims + 1):
            matrix = matrix.unsqueeze(1)
            expanded_size.insert(axis, attention.size(axis))
        matrix = matrix.expand(*expanded_size)

    weighted_rows = attention.unsqueeze(-1).expand_as(matrix) * matrix
    return weighted_rows.sum(dim=-2)

In [None]:
def masked_softmax(
        vector: torch.Tensor,
        mask: torch.BoolTensor,
        dim: int=-1,
        memory_efficient: bool=False, ) -> torch.Tensor:
    """
    Plain F.softmax(vector) cannot be used because some elements of the vector may be
    masked out, so this function applies softmax over the unmasked part only.

    With mask=None this is an ordinary softmax. With memory_efficient=False the
    masked positions are zeroed before and after the softmax and the result is
    renormalised; with memory_efficient=True the masked positions are filled with
    the smallest value of the dtype before a single softmax.
    """
    if mask is None:
        return F.softmax(vector, dim=dim)

    # Add dimensions after the batch dimension until the mask has the vector's rank.
    while mask.ndim < vector.ndim:
        mask = mask.unsqueeze(1)

    if memory_efficient:
        filled = vector.masked_fill(
            ~mask, min_value_of_dtype(vector.dtype))
        return F.softmax(filled, dim=dim)

    # To limit numerical errors from large vector elements outside the mask,
    # we zero these out.
    probabilities = F.softmax(vector * mask, dim=dim) * mask
    normaliser = probabilities.sum(dim=dim, keepdim=True) + tiny_value_of_dtype(probabilities.dtype)
    return probabilities / normaliser

In [None]:
def replace_masked_values(
    tensor: torch.Tensor, mask: torch.BoolTensor, replace_with: float
) -> torch.Tensor:
    """
    Replaces all masked values in `tensor` with `replace_with`.  `mask` must be broadcastable
    to the same shape as `tensor`. We require that `tensor.dim() == mask.dim()`, as otherwise we
    won't know which dimensions of the mask to unsqueeze.
    This just does `tensor.masked_fill()`, except the pytorch method fills in things with a mask
    value of 1, where we want the opposite.  You can do this in your own code with
    `tensor.masked_fill(~mask, replace_with)`.
    """
    tensor_rank, mask_rank = tensor.dim(), mask.dim()
    if tensor_rank != mask_rank:
        raise ConfigurationError(
            "tensor.dim() (%d) != mask.dim() (%d)" % (tensor_rank, mask_rank)
        )
    # Positions where the mask is False are the ones that get overwritten.
    positions_to_replace = ~mask
    return tensor.masked_fill(positions_to_replace, replace_with)

In [None]:
def attention(query, key, mask=None, dropout=None):
    """
    Scaled dot-product attention weights.

    The scores query @ key^T are divided by sqrt(d_k), where d_k is the size of the
    last dimension of query. Positions where mask == 0 are set to -1e9 before the
    softmax over the last dimension. When a dropout module is given it is applied
    to the weights. Only the weights are returned; no values are combined here.
    """
    scale = math.sqrt(query.size(-1))
    scores = (query @ key.transpose(-2, -1)) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    weights = scores.softmax(dim=-1)
    return weights if dropout is None else dropout(weights)


def clones(module, N):
    """Returns an nn.ModuleList holding N independent deep copies of module."""
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return copies


In [None]:
class TimeDistributed(nn.Module):
    """
    Wraps a Module that expects inputs of shape (batch_size, [rest]) so that it can be
    applied to inputs of shape (batch_size, time_steps, [rest]).
    The inputs are reshaped to (batch_size * time_steps, [rest]), the wrapped Module
    is applied, and the output is reshaped back.
    """

    def __init__(self, module):
        super().__init__()
        self._module = module

    def forward(self, *inputs, pass_through: List[str]=None, **kwargs):
        """
        Positional inputs and tensor keyword arguments are squashed before the call;
        keyword arguments named in `pass_through`, and non-tensor keyword arguments,
        are handed to the wrapped module unchanged.
        """
        untouched = pass_through or []

        # Squash every positional input.
        squashed_args = [self._reshape_tensor(tensor) for tensor in inputs]

        # A reference tensor is needed to recover batch_size and time_steps:
        # the last positional input if there is one, otherwise the first
        # time-distributed keyword tensor.
        reference = inputs[-1] if inputs else None

        squashed_kwargs = {}
        for name, value in kwargs.items():
            if isinstance(value, torch.Tensor) and name not in untouched:
                if reference is None:
                    reference = value
                value = self._reshape_tensor(value)
            squashed_kwargs[name] = value

        # Apply the wrapped module to the squashed data.
        squashed_out = self._module(*squashed_args, **squashed_kwargs)

        if reference is None:
            raise RuntimeError("No input tensor to time-distribute")

        # Restore the leading dimensions:
        # (batch_size, time_steps, **output_size)
        restored_shape = reference.shape[:2] + squashed_out.shape[1:]
        return squashed_out.contiguous().reshape(restored_shape)

    @staticmethod
    def _reshape_tensor(input_tensor):
        input_size = input_tensor.shape
        if len(input_size) <= 2:
            raise RuntimeError(f"No dimension to distribute: {input_size}")
        # Merge batch_size and time_steps into a single axis, shape:
        # (batch_size * time_steps, **input_size).
        return input_tensor.reshape(-1, *input_size[2:])

### Biaffine

In [None]:
class Biaffine(nn.Module):

    """
    Biaffine scoring of every token against every token, averaged per token.

    Takes two tensors of shape (batch_size, seq_max_len, input_dim) (per the original
    notes these are the outputs of a BiLSTM), computes a dependency vector for every
    (token, token) pair and returns the vectors averaged over the second token,
    shape (batch_size, seq_max_len, dep_vec_dim).

    # Parameters
    input_dim : size of the last dimension of both inputs
    dep_vec_dim : size of the produced dependency vectors
    """

    def __init__(self, input_dim, dep_vec_dim):
        super().__init__()

        self.input_dim = input_dim
        self.dep_vec_dim = dep_vec_dim

        # U_1: bilinear term, U_2: linear term over the concatenation [h_forward; h_backward].
        self.U_1 = Parameter(torch.empty(input_dim, dep_vec_dim, input_dim))
        self.U_2 = Parameter(torch.empty(2*input_dim, dep_vec_dim))
        self.bias = Parameter(torch.zeros(dep_vec_dim))

        nn.init.xavier_uniform_(self.U_1)
        nn.init.xavier_uniform_(self.U_2)
        nn.init.constant_(self.bias, 0.)

    def forward(self, h_forward, h_backward):
        """
        h_forward, h_backward : (batch_size, seq_len, input_dim)
        Returns a tensor of shape (batch_size, seq_len of h_forward, dep_vec_dim).
        """
        # Hf.T*U1*Hb # U1 - h*r*h, h - Hf/Hb dim, r - dep_vec_dim
        # batch x seq_len x seq_len x dep_vec_dim
        left_part = torch.einsum('bxi,irj,byj->bxyr', h_forward, self.U_1, h_backward)

        # (Hf (+) Hb).T*U2 # U2 - 2h*r
        # Multiplying the concatenation [hf; hb] by U2 equals hf @ U2[:h] + hb @ U2[h:],
        # so the two halves are applied separately and combined by broadcasting.
        # This avoids building two tiled (batch, seq_len, seq_len, h) tensors and
        # their concatenation.
        u2_forward = self.U_2[:self.input_dim]
        u2_backward = self.U_2[self.input_dim:]
        right_part = ((h_forward @ u2_forward).unsqueeze(2)
                      + (h_backward @ u2_backward).unsqueeze(1))

        # batch x seq_len x seq_len x dep_vec_dim
        biaff_matrix = left_part + right_part + self.bias

        # Average over the second token.
        # batch x seq_len x dep_vec_dim
        return biaff_matrix.mean(dim=2)

### Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention over a single sequence tensor.

    n_dim is the model width and n_heads the number of heads; the width is split
    evenly across the heads. The key projection has no bias.
    """

    def __init__(self, n_dim: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.query_weight = nn.Linear(n_dim, n_dim)
        self.key_weight = nn.Linear(n_dim, n_dim, bias=False)
        self.value_weight = nn.Linear(n_dim, n_dim)
        self.linear = nn.Linear(n_dim, n_dim)

    def forward(
        self,
        sequence: Tensor,
        mask: Optional[Tensor]=None
    ):
        """
        sequence : (batch, length, n_dim)
        mask : optional tensor that is added to the attention scores
        Returns (projected attention output, detached pre-softmax scores).
        """
        attended, scores = self.compute_attention(
            self.query_weight(sequence),
            self.key_weight(sequence),
            self.value_weight(sequence),
            mask,
        )
        return self.linear(attended), scores

    def compute_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        _, n_ctx, n_state = q.shape
        # Queries and keys are both scaled by head_dim ** -0.25, which together
        # gives the usual 1 / sqrt(head_dim) factor on their product.
        scale = (n_state // self.n_heads) ** -0.25

        def split_heads(tensor):
            # (batch, length, n_state) -> (batch, length, heads, head_dim)
            return tensor.view(*tensor.shape[:2], self.n_heads, -1)

        q = split_heads(q).transpose(1, 2) * scale      # (batch, heads, length, head_dim)
        k = split_heads(k).permute(0, 2, 3, 1) * scale  # (batch, heads, head_dim, length)
        v = split_heads(v).transpose(1, 2)              # (batch, heads, length, head_dim)

        scores = torch.matmul(q, k)
        if mask is not None:
            # The mask is additive and is cropped to the current sequence length.
            scores = scores + mask[:n_ctx, :n_ctx]
        scores = scores.float()

        weights = torch.softmax(scores, dim=-1).to(q.dtype)
        merged = torch.matmul(weights, v).transpose(1, 2).flatten(start_dim=2)
        return merged, scores.detach()


### AGGCN modules


In [None]:
class GraphConvLayer(nn.Module):
    """
    A densely connected GCN block operated on dependency graphs.

    The block stacks `layers` graph-convolution sub-layers. Sub-layer `i`
    receives the block input concatenated with the outputs of all previous
    sub-layers and emits `mem_dim // layers` features; the sub-layer outputs
    are concatenated, added to the block input and linearly projected.

    `mem_dim` must be divisible by `layers`, otherwise the residual addition
    in `forward` fails on a shape mismatch.
    """

    def __init__(self, tree_dropout, mem_dim, layers):
        super(GraphConvLayer, self).__init__()
        self.mem_dim = mem_dim
        self.layers = layers
        self.head_dim = self.mem_dim // self.layers
        self.gcn_drop = nn.Dropout(tree_dropout)

        # linear transformation applied to the aggregated block output
        self.linear_output = nn.Linear(self.mem_dim, self.mem_dim)

        # dcgcn block: the input width of each sub-layer grows by head_dim
        self.weight_list = nn.ModuleList()
        for i in range(self.layers):
            self.weight_list.append(nn.Linear((self.mem_dim + self.head_dim * i), self.head_dim))

    def forward(self, adj, gcn_inputs):
        """
        Run the block.

        adj: (batch, seq, seq) adjacency matrix of any dtype convertible to
            the layer's parameter dtype (float64, float32, bool, int, ...).
        gcn_inputs: (batch, seq, mem_dim) node features.

        Returns a (batch, seq, mem_dim) tensor in the parameter dtype.
        """
        # Do all arithmetic in the dtype of the layer's own weights. The
        # previous implementation promoted the features to float64 before the
        # first bmm, which only worked for a float64 adjacency matrix and
        # crashed for float32 / bool / integer ones. Results for float64
        # adjacency now differ only at float32 rounding level.
        compute_dtype = self.linear_output.weight.dtype
        adj = adj.to(compute_dtype)
        gcn_inputs = gcn_inputs.to(compute_dtype)

        # node degree + 1 (the +1 accounts for the self loop)
        denom = adj.sum(2).unsqueeze(2) + 1

        outputs = gcn_inputs
        cache_list = [outputs]
        output_list = []
        for l in range(self.layers):
            Ax = adj.bmm(outputs)
            AxW = self.weight_list[l](Ax)
            AxW = AxW + self.weight_list[l](outputs)  # self loop
            AxW = AxW / denom

            gAxW = F.relu(AxW)
            # Dense connectivity: later sub-layers see the un-dropped output,
            # while dropout is applied only to the copy that is aggregated.
            cache_list.append(gAxW)
            outputs = torch.cat(cache_list, dim=2)
            output_list.append(self.gcn_drop(gAxW))

        gcn_outputs = torch.cat(output_list, dim=2)
        gcn_outputs = gcn_outputs + gcn_inputs  # residual connection

        out = self.linear_output(gcn_outputs)
        return out

In [None]:
class MultiGraphConvLayer(nn.Module):
    """
    Multi-head variant of the densely connected GCN block.

    Every head owns its own stack of `layers` sub-layers and is driven by its
    own adjacency matrix. The per-head results are concatenated along the
    feature axis and projected back to `mem_dim`.
    """

    def __init__(self, tree_dropout, mem_dim, layers, heads):
        super().__init__()
        self.mem_dim = mem_dim
        self.layers = layers
        self.head_dim = self.mem_dim // self.layers
        self.heads = heads
        self.gcn_drop = nn.Dropout(tree_dropout)

        # projection that merges the concatenated heads
        self.Linear = nn.Linear(self.mem_dim * self.heads, self.mem_dim)
        # flat list of sub-layer weights, laid out head after head
        self.weight_list = nn.ModuleList(
            nn.Linear(self.mem_dim + self.head_dim * depth, self.head_dim)
            for _ in range(self.heads)
            for depth in range(self.layers)
        )

    def _run_head(self, adj, gcn_inputs, first_weight):
        # One head: dense stack of graph convolutions plus a residual link.
        degree = adj.sum(dim=2, keepdim=True) + 1
        running = gcn_inputs
        dense_cache = [gcn_inputs]
        dropped = []
        for depth in range(self.layers):
            weight = self.weight_list[first_weight + depth]
            # neighbour aggregation + self loop, normalised by the degree
            activated = F.relu((weight(adj.bmm(running)) + weight(running)) / degree)
            dense_cache.append(activated)
            running = torch.cat(dense_cache, dim=2)
            dropped.append(self.gcn_drop(activated))
        return torch.cat(dropped, dim=2) + gcn_inputs

    def forward(self, adj_list, gcn_inputs):
        """
        adj_list: indexable collection with one adjacency tensor per head.
        gcn_inputs: node features with `mem_dim` channels on the last axis.
        """
        per_head = [
            self._run_head(adj_list[head], gcn_inputs, head * self.layers)
            for head in range(self.heads)
        ]
        return self.Linear(torch.cat(per_head, dim=2))

In [None]:
class GraphMultiHeadAttention(nn.Module):
    """
    Computes per-head attention between a query and a key sequence.

    Only query and key projections exist (two clones of one linear layer);
    whatever the external `attention` helper returns is handed back unchanged.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0

        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 2)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, mask=None):
        # add a head axis so one mask is shared by every head
        head_mask = None if mask is None else mask.unsqueeze(1)

        batch = query.size(0)

        # project, then reshape (batch, seq, d_model) -> (batch, h, seq, d_k)
        projected = []
        for linear, tensor in zip(self.linears, (query, key)):
            per_head = linear(tensor).view(batch, -1, self.h, self.d_k)
            projected.append(per_head.transpose(1, 2))
        query_heads, key_heads = projected

        return attention(query_heads, key_heads, mask=head_mask, dropout=self.dropout)

In [None]:
class AGGCN(nn.Module):
    """
    Attention-guided graph convolutional encoder.

    The first propagation step uses two `GraphConvLayer` blocks driven by the
    supplied adjacency matrix. Every further step (`tree_prop > 1`) uses two
    `MultiGraphConvLayer` blocks driven by attention maps computed from the
    current node states. The outputs of all blocks are concatenated, projected
    back to `span_emb_dim` and passed through a Linear + ReLU head producing
    `feature_dim` channels.
    """

    def __init__(self,
                 span_emb_dim: int,
                 feature_dim: int,
                 tree_prop: int = 1,
                 tree_dropout: float=0.0, 
                 aggcn_heads: int=4,
                 aggcn_sublayer_first: int=2,
                 aggcn_sublayer_second: int=4):
        super(AGGCN, self).__init__()

        self.in_dim = span_emb_dim
        self.mem_dim = span_emb_dim

        self.input_W_G = nn.Linear(self.in_dim, self.mem_dim)

        self.num_layers = tree_prop

        self.layers = nn.ModuleList()

        self.heads = aggcn_heads
        self.sublayer_first = aggcn_sublayer_first
        self.sublayer_second = aggcn_sublayer_second

        # gcn layer: two blocks per propagation step
        for i in range(self.num_layers):
            if i == 0:
                self.layers.append(GraphConvLayer(tree_dropout, self.mem_dim, self.sublayer_first))
                self.layers.append(GraphConvLayer(tree_dropout, self.mem_dim, self.sublayer_second))
            else:
                self.layers.append(MultiGraphConvLayer(tree_dropout, self.mem_dim, self.sublayer_first, self.heads))
                self.layers.append(MultiGraphConvLayer(tree_dropout, self.mem_dim, self.sublayer_second, self.heads))

        self.aggregate_W = nn.Linear(len(self.layers) * self.mem_dim, self.mem_dim)

        self.attn = GraphMultiHeadAttention(self.heads, self.mem_dim)

        # mlp output layer
        in_dim = span_emb_dim
        mlp_layers = [nn.Linear(in_dim, feature_dim), nn.ReLU()]
        self.out_mlp = nn.Sequential(*mlp_layers)

    # adj: (batch, sequence, sequence)
    # text_embeddings: (batch, sequence, emb_dim)
    # text_mask: (batch, sequence)
    def forward(self, adj, text_embeddings, text_mask=None):
        """
        Encode `text_embeddings` with the graph layers.

        `text_mask` is only consumed by the attention-guided blocks (index 2
        and above) and may be omitted; it is then forwarded as `None` to the
        attention module.
        """
        gcn_inputs = self.input_W_G(text_embeddings)
        if text_mask is not None:
            # (batch, 1, sequence) so the mask broadcasts over query positions
            text_mask = text_mask.unsqueeze(-2)

        layer_list = []
        outputs = gcn_inputs
        for i, layer in enumerate(self.layers):
            if i < 2:
                # the first two blocks use the dependency adjacency directly
                outputs = layer(adj, outputs)
            else:
                # later blocks use one attention map per head as soft adjacency
                attn_tensor = self.attn(outputs, outputs, text_mask)
                attn_adj_list = [attn_adj.squeeze(1) for attn_adj in torch.split(attn_tensor, 1, dim=1)]
                outputs = layer(attn_adj_list, outputs)
            layer_list.append(outputs)

        aggregate_out = torch.cat(layer_list, dim=2)
        dcgcn_output = self.aggregate_W(aggregate_out)

        outputs = self.out_mlp(dcgcn_output)
        return outputs

### Extractors

In [None]:
class SpanExtractor(nn.Module):
    """
    Abstract base for modules that turn token sequences into span vectors.

    A concrete extractor receives a sequence tensor of shape
    (batch_size, timesteps, embedding_dim) together with span indices of shape
    (batch_size, num_spans, 2) and produces a tensor of shape
    (batch_size, num_spans, ...) holding one representation per span.
    """

    def forward(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.BoolTensor=None,
            span_indices_mask: torch.BoolTensor=None, ):
        """
        Extract the requested spans and return their representations.

        How a span is represented is up to the subclass (for example a
        concatenation of its endpoints or attention over its tokens).

        # Parameters
        sequence_tensor : `torch.FloatTensor`, required.
            Embedded words, shape (batch_size, sequence_length, embedding_size).
        span_indices : `torch.LongTensor`, required.
            Shape `(batch_size, num_spans, 2)`; the last axis stores the
            inclusive start and end index of each span in `sequence_tensor`.
        sequence_mask : `torch.BoolTensor`, optional (default = `None`).
            Shape (batch_size, sequence_length); marks padded sequence elements.
        span_indices_mask : `torch.BoolTensor`, optional (default = `None`).
            Shape (batch_size, num_spans); marks the valid spans. Optional,
            because masking is sometimes easier to apply after this call.

        # Returns
        A tensor of shape `(batch_size, num_spans, embedded_span_size)`; the
        last dimension depends on the chosen span representation.
        """
        raise NotImplementedError

    def get_input_dim(self) -> int:
        """
        Size of the last dimension expected for `sequence_tensor`.
        """
        raise NotImplementedError

    def get_output_dim(self) -> int:
        """
        Size of the last dimension of the returned span representation.
        """
        raise NotImplementedError

In [None]:
class MaxPoolingSpanExtractor(SpanExtractor):
    """
    Represents spans by max-pooling over the token vectors they cover.

    Given a span x_i, ..., x_j, where i and j are its inclusive start and end,
    each dimension d of the span representation is s_d = max(x_id, ..., x_jd).
    When width embeddings are configured, a learned span-width embedding is
    concatenated to the pooled vector.
    Registered as a `SpanExtractor` with name "max_pooling".
    # Parameters
    input_dim : `int`, required.
        The final dimension of the `sequence_tensor`.
    num_width_embeddings : `int`, optional (default = `None`).
        Specifies the number of buckets to use when representing
        span width features.
    span_width_embedding_dim : `int`, optional (default = `None`).
        The embedding size for the span_width features.
    bucket_widths : `bool`, optional (default = `False`).
        Whether to bucket the span widths into log-space buckets. If `False`,
        the raw span widths are used.
    # Returns
    span_embeddings : `torch.FloatTensor`.
    A tensor of shape `(batch_size, num_spans, embedded_span_size)`,
    where `embedded_span_size` depends on the way spans are represented.
    """

    def __init__(
            self,
            input_dim: int,
            num_width_embeddings: int=None,
            span_width_embedding_dim: int=None,
            bucket_widths: bool=False, ) -> None:
        super().__init__()

        self._input_dim = input_dim
        self._num_width_embeddings = num_width_embeddings
        self._bucket_widths = bucket_widths
        # Always define the attribute: `forward` and `get_output_dim` test it
        # against None, which raised AttributeError when it was left unset.
        self._span_width_embedding = None

        if num_width_embeddings is not None and span_width_embedding_dim is not None:
            self._span_width_embedding = nn.Embedding(
                                         num_embeddings=num_width_embeddings,
                                         embedding_dim=span_width_embedding_dim)
        elif num_width_embeddings is not None or span_width_embedding_dim is not None:
            raise ConfigurationError(
                "To use a span width embedding representation, you must "
                "specify both num_width_embeddings and span_width_embedding_dim."
            )

    def get_input_dim(self) -> int:
        """Returns the expected final dimension of the `sequence_tensor`."""
        return self._input_dim

    def get_output_dim(self) -> int:
        """Returns the final dimension of the span representation."""
        if self._span_width_embedding is not None:
            # nn.Embedding exposes its width as `embedding_dim`.
            return self._input_dim + self._span_width_embedding.embedding_dim
        return self._input_dim

    def _embed_spans(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.BoolTensor=None,
            span_indices_mask: torch.BoolTensor=None, ) -> torch.FloatTensor:
        """
        Max-pool the token vectors of every span.

        Raises `ValueError` on an embedding-size mismatch and `IndexError`
        when span indices are out of range, inverted, or lie entirely in the
        padded part of the sequence. `span_indices_mask` is not used here.
        """
        if sequence_tensor.size(-1) != self._input_dim:
            raise ValueError(
                f"Dimension mismatch expected ({self._input_dim}) "
                f"received ({sequence_tensor.size(-1)}).")
        if sequence_tensor.shape[1] <= span_indices.max() or span_indices.min(
        ) < 0:
            raise IndexError(
                f"Span index out of range, max index ({span_indices.max()}) "
                f"or min index ({span_indices.min()}) "
                f"not valid for sequence of length ({sequence_tensor.shape[1]})."
            )

        if (span_indices[:, :, 0] > span_indices[:, :, 1]).any():
            raise IndexError("Span start above span end", )

        # Calculate the maximum sequence length for each element in batch.
        # If span_end indices are above these length, we adjust the indices in adapted_span_indices
        if sequence_mask is not None:
            # shape (batch_size)
            sequence_lengths = get_lengths_from_binary_sequence_mask(
                sequence_mask)
        else:
            # shape (batch_size), filled with the sequence length size of the sequence_tensor.
            sequence_lengths = torch.ones_like(
                sequence_tensor[:, 0, 0],
                dtype=torch.long) * sequence_tensor.size(1)

        # Work on a private copy so the caller's indices stay untouched.
        # clone().detach() is the supported way to copy an existing tensor;
        # torch.tensor(tensor) emits a UserWarning.
        adapted_span_indices = span_indices.clone().detach()

        # Walk over the batch: any span end at or beyond the real sequence
        # length is pulled back to the last valid position (length - 1).
        for b in range(sequence_lengths.shape[0]):
            adapted_span_indices[b, :, 1][adapted_span_indices[b, :, 1] >=
                                          sequence_lengths[b]] = (
                                              sequence_lengths[b] - 1)

        # Raise Error if span indices were completely masked by sequence mask.
        # We only adjust span_end to the last valid index, so if span_end is below span_start,
        # both were above the max index:
        if (adapted_span_indices[:, :, 0] > adapted_span_indices[:, :, 1]
            ).any():
            raise IndexError(
                "Span indices were masked out entirely by sequence mask", )

        # span_vals <- (batch x num_spans x max_span_length x dim)
        span_vals, span_mask = batched_span_select(sequence_tensor,
                                                   adapted_span_indices)

        # The application of masked_max requires a mask of the same shape as span_vals
        # We repeat the mask along the last dimension (embedding dimension)
        repeat_dim = len(span_vals.shape) - 1
        repeat_idx = [1] * (repeat_dim) + [span_vals.shape[-1]]

        # Shape: (batch x num_spans x max_span_length x dim)
        # ext_span_mask True for values in span, False for masked out values
        ext_span_mask = span_mask.unsqueeze(repeat_dim).repeat(repeat_idx)

        # Shape: (batch x num_spans x embedding_dim)
        max_output = masked_max(span_vals, ext_span_mask, dim=-2)

        return max_output

    def forward(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.BoolTensor=None,
            span_indices_mask: torch.BoolTensor=None, ):
        """
        Extracts the spans, builds their pooled embedding and, when width
        embeddings are configured, concatenates the span-width embedding.
        # Parameters
        sequence_tensor : `torch.FloatTensor`, required.
            A tensor of shape (batch_size, sequence_length, embedding_size)
            representing an embedded sequence of words.
        span_indices : `torch.LongTensor`, required.
            A tensor of shape `(batch_size, num_spans, 2)`, where the last
            dimension represents the inclusive start and end indices of the
            span to be extracted from the `sequence_tensor`.
        sequence_mask : `torch.BoolTensor`, optional (default = `None`).
            A tensor of shape (batch_size, sequence_length) representing padded
            elements of the sequence.
        span_indices_mask : `torch.BoolTensor`, optional (default = `None`).
            A tensor of shape (batch_size, num_spans) representing the valid
            spans in the `indices` tensor. This mask is optional because
            sometimes it's easier to worry about masking after calling this
            function, rather than passing a mask directly.
        # Returns
        A tensor of shape `(batch_size, num_spans, embedded_span_size)`,
        where `embedded_span_size` depends on the way spans are represented.
        """
        # shape (batch_size, num_spans, embedding_dim)
        span_embeddings = self._embed_spans(sequence_tensor, span_indices,
                                            sequence_mask, span_indices_mask)
        if self._span_width_embedding is not None:
            # width = end_index - start_index + 1 since `SpanField` use inclusive indices.
            # But here we do not add 1 because we often initiate the span width
            # embedding matrix with `num_width_embeddings = max_span_width`
            # shape (batch_size, num_spans)
            widths_minus_one = span_indices[..., 1] - span_indices[..., 0]

            if self._bucket_widths:
                widths_minus_one = bucket_values(
                    widths_minus_one,
                    num_total_buckets=self._num_width_embeddings)  # type: ignore

            # Embed the span widths and concatenate to the rest of the representations.
            span_width_embeddings = self._span_width_embedding(
                widths_minus_one)
            span_embeddings = torch.cat(
                [span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            # Here we are masking the spans which were originally passed in as padding.
            return span_embeddings * span_indices_mask.unsqueeze(-1)

        return span_embeddings

In [None]:
class SelfAttentiveSpanExtractor(SpanExtractor):
    """
    Computes span representations by generating an attention score for every
    word in the text. A span is represented by the weighted sum of its word
    vectors, with the scores normalised over the words inside the span.
    Registered as a `SpanExtractor` with name "self_attentive".
    # Parameters
    input_dim : `int`, required.
        The final dimension of the `sequence_tensor`.
    reduced_dim : `int`, required.
        Size of the per-word vectors after the reducing linear layer; this is
        the size of the span embedding before any width embedding is appended.
    num_width_embeddings : `int`, optional (default = `None`).
        Specifies the number of buckets to use when representing
        span width features.
    span_width_embedding_dim : `int`, optional (default = `None`).
        The embedding size for the span_width features.
    bucket_widths : `bool`, optional (default = `False`).
        Whether to bucket the span widths into log-space buckets. If `False`,
        the raw span widths are used.
    # Returns
    attended_text_embeddings : `torch.FloatTensor`.
        A tensor of shape (batch_size, num_spans, reduced_dim), which each span representation
        is formed by locally normalising a global attention over the sequence. The only way
        in which the attention distribution differs over different spans is in the set of words
        over which they are normalized.
    """

    def __init__(
            self,
            input_dim: int,
            reduced_dim: int,
            num_width_embeddings: int=None,
            span_width_embedding_dim: int=None,
            bucket_widths: bool=False, ) -> None:
        super().__init__()

        self._input_dim = input_dim
        self._num_width_embeddings = num_width_embeddings
        self._bucket_widths = bucket_widths
        self._heads = 4
        # Always define the attribute: `forward` tests it against None, which
        # raised AttributeError when width embeddings were not configured.
        self._span_width_embedding = None

        if num_width_embeddings is not None and span_width_embedding_dim is not None:
            self._span_width_embedding = nn.Embedding(
                                         num_embeddings=num_width_embeddings,
                                         embedding_dim=span_width_embedding_dim)
        elif num_width_embeddings is not None or span_width_embedding_dim is not None:
            raise ConfigurationError(
                "To use a span width embedding representation, you must "
                "specify both num_width_embeddings and span_width_embedding_dim."
            )

        self.attn = MultiHeadAttention(self._input_dim, self._heads)
        self.attn_dropout = nn.Dropout(0.2)
        # Input is [word embedding ; attention output]; the extra output
        # channel (+1) carries the per-word attention logit.
        self.dim_reducer = nn.Linear(self._input_dim+self._input_dim, reduced_dim+1)

    def _embed_spans(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.BoolTensor=None,
            span_indices_mask: torch.BoolTensor=None, ) -> torch.FloatTensor:
        """
        Build attention-weighted span embeddings.

        `sequence_mask` and `span_indices_mask` are accepted for interface
        compatibility but are not used here.
        """
        attention_output, _ = self.attn(sequence_tensor)
        attention_output = self.attn_dropout(attention_output)

        # shape (batch_size, sequence_length, 2 * embedding_dim)
        concat_tensor = torch.cat([sequence_tensor, attention_output],
                                  -1)
        # shape (batch_size, sequence_length, reduced_dim + 1)
        # The reducer must be applied to the concatenation: its input width is
        # 2 * input_dim, and the span selection below needs `reduced_tensor`.
        reduced_tensor = self.dim_reducer(concat_tensor)

        concat_output, span_mask = batched_span_select(reduced_tensor,
                                                       span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width, reduced_dim)
        span_embeddings = concat_output[:, :, :, :-1]
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = concat_output[:, :, :, -1]

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = masked_softmax(span_attention_logits,
                                                span_mask)

        # Do a weighted sum of the embedded spans with
        # respect to the normalised attention distributions.
        # Shape: (batch_size, num_spans, reduced_dim)
        attended_text_embeddings = weighted_sum(span_embeddings,
                                                span_attention_weights)

        return attended_text_embeddings

    def forward(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.BoolTensor=None,
            span_indices_mask: torch.BoolTensor=None, ):
        """
        Extracts the spans, builds their semantic embedding and, when width
        embeddings are configured, concatenates the span-width embedding.
        # Parameters
        sequence_tensor : `torch.FloatTensor`, required.
            A tensor of shape (batch_size, sequence_length, embedding_size)
            representing an embedded sequence of words.
        span_indices : `torch.LongTensor`, required.
            A tensor of shape `(batch_size, num_spans, 2)`, where the last
            dimension represents the inclusive start and end indices of the
            span to be extracted from the `sequence_tensor`.
        sequence_mask : `torch.BoolTensor`, optional (default = `None`).
            A tensor of shape (batch_size, sequence_length) representing padded
            elements of the sequence.
        span_indices_mask : `torch.BoolTensor`, optional (default = `None`).
            A tensor of shape (batch_size, num_spans) representing the valid
            spans in the `indices` tensor. This mask is optional because
            sometimes it's easier to worry about masking after calling this
            function, rather than passing a mask directly.
        # Returns
        A tensor of shape `(batch_size, num_spans, embedded_span_size)`,
        where `embedded_span_size` depends on the way spans are represented.
        """
        # shape (batch_size, num_spans, reduced_dim)
        span_embeddings = self._embed_spans(sequence_tensor, span_indices,
                                            sequence_mask, span_indices_mask)
        if self._span_width_embedding is not None:
            # width = end_index - start_index + 1 since `SpanField` use inclusive indices.
            # But here we do not add 1 because we often initiate the span width
            # embedding matrix with `num_width_embeddings = max_span_width`
            # shape (batch_size, num_spans)
            widths_minus_one = span_indices[..., 1] - span_indices[..., 0]

            if self._bucket_widths:
                widths_minus_one = bucket_values(
                    widths_minus_one,
                    num_total_buckets=self._num_width_embeddings)  # type: ignore

            # Embed the span widths and concatenate to the rest of the representations.
            span_width_embeddings = self._span_width_embedding(
                widths_minus_one)
            span_embeddings = torch.cat(
                [span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            # Here we are masking the spans which were originally passed in as padding.
            return span_embeddings * span_indices_mask.unsqueeze(-1)
        return span_embeddings

In [None]:
# Biaffine span extractor combined with multi-head self-attention

class SelfBiaffineSpanExtractor(SpanExtractor):
    """
    Span extractor that enriches every word with several views before pooling.

    Each word vector is concatenated with a biaffine dependency vector (built
    from the forward and backward halves of a BiLSTM), the output of a
    multi-head self-attention layer and, when `use_gcn` is set, the output of
    the AGGCN graph module. A linear layer reduces the concatenation to
    `reduced_dim + 1` channels, where the extra channel is a per-word attention
    logit. A span is the attention-weighted sum of its reduced word vectors.
    # Parameters
    input_dim : `int`, required.
        The final dimension of the `sequence_tensor`.
    reduced_dim : `int`, required.
        Size of the span embedding before any width embedding is appended.
    num_width_embeddings : `int`, optional (default = `None`).
        Number of buckets used for the span width features.
    span_width_embedding_dim : `int`, optional (default = `None`).
        The embedding size for the span_width features.
    bucket_widths : `bool`, optional (default = `False`).
        Whether to bucket the span widths via `bucket_values`.
    use_gcn : `bool`, optional (default = `True`).
        Whether to add the graph-module features; requires `adj_matrix`
        at call time.
    """

    def __init__(
            self,
            input_dim: int,
            reduced_dim: int,
            num_width_embeddings: int=None,
            span_width_embedding_dim: int=None,
            bucket_widths: bool=False, 
            use_gcn: bool=True,) -> None:
        super().__init__()

        self._input_dim = input_dim
        self._num_width_embeddings = num_width_embeddings
        self._bucket_widths = bucket_widths
        self._use_gcn = use_gcn
        self._span_width_embedding = None
        self._heads = 4

        if num_width_embeddings is not None and span_width_embedding_dim is not None:
            self._span_width_embedding = nn.Embedding(
                                         num_embeddings=num_width_embeddings,
                                         embedding_dim=span_width_embedding_dim)
        elif num_width_embeddings is not None or span_width_embedding_dim is not None:
            raise ConfigurationError(
                "To use a span width embedding representation, you must "
                "specify both num_width_embeddings and span_width_embedding_dim."
            )
        self.attn = MultiHeadAttention(self._input_dim, self._heads)

        # Half of the input width per direction (usually 768 / 2), so the
        # bidirectional output has input_dim channels again.
        self.lstm_dim = self._input_dim// 2
        # batch_first - (batch, seq, 2*dim)
        self.bilstm = nn.LSTM(self._input_dim, self.lstm_dim, 
                         num_layers=1, bidirectional=True, batch_first=True)

        self.dep_vec_dim = 256        
        self.biaffine = Biaffine(self.lstm_dim, self.dep_vec_dim)

        self.gcn_dim = 0
        if self._use_gcn:
            self.gcn_dim = 256
            self.graph_module = AGGCN(self._input_dim, self.gcn_dim,
                                tree_prop= 1,
                                tree_dropout=0.2, 
                                aggcn_heads=4,
                                aggcn_sublayer_first=2,
                                aggcn_sublayer_second=4)

        # Input: [word ; dependency vector ; graph features ; attention output].
        # Output: reduced_dim channels plus one attention-logit channel.
        self.dim_reducer = nn.Linear(self._input_dim+self.dep_vec_dim+self.gcn_dim+self._input_dim, reduced_dim+1)

    def _embed_spans(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            adj_matrix: torch.FloatTensor=None,
            sequence_mask:  torch.FloatTensor=None,
            span_indices_mask: torch.BoolTensor=None) -> torch.FloatTensor:
        """
        Build attention-weighted span embeddings from the enriched word vectors.

        Raises `ValueError` when the extractor was built with `use_gcn=True`
        but no `adj_matrix` is supplied. `span_indices_mask` is not used here.
        """
        if self._use_gcn and adj_matrix is None:
            # Fail early with a readable message instead of an AttributeError
            # raised from inside the graph module.
            raise ValueError(
                "adj_matrix is required when the extractor is built with use_gcn=True.")

        attention_output, _ = self.attn(sequence_tensor)

        output, _ = self.bilstm(sequence_tensor)
        # split the BiLSTM output into its forward and backward halves
        h_forward = output[:, :, :self.lstm_dim]
        h_backward = output[:, :, self.lstm_dim:]

        dep_output = self.biaffine(h_forward, h_backward)
        if self._use_gcn:
            graph_output = self.graph_module(adj_matrix, sequence_tensor, sequence_mask)
            # batch x seq_len x (2 * input_dim + dep_vec_dim + gcn_dim)
            concat_tensor = torch.cat((sequence_tensor, dep_output, graph_output, attention_output), -1)
        else:
            # batch x seq_len x (2 * input_dim + dep_vec_dim)
            concat_tensor = torch.cat((sequence_tensor, dep_output, attention_output), -1)

        # batch x seq_len x (reduced_dim + 1)
        reduced_tensor = self.dim_reducer(concat_tensor)

        concat_output, span_mask = batched_span_select(reduced_tensor,
                                                       span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width, reduced_dim)
        span_embeddings = concat_output[:, :, :, :-1]
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = concat_output[:, :, :, -1]

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = masked_softmax(span_attention_logits,
                                                span_mask)

        # Do a weighted sum of the embedded spans with
        # respect to the normalised attention distributions.
        # Shape: (batch_size, num_spans, reduced_dim); the span-width axis is
        # summed out, which is why the rank drops by one here.
        attended_text_embeddings = weighted_sum(span_embeddings,
                                                span_attention_weights)
        return attended_text_embeddings

    def forward(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            adj_matrix: torch.FloatTensor=None,
            sequence_mask: torch.FloatTensor=None,
            span_indices_mask: torch.BoolTensor=None):
        """
        Extracts the spans, builds their semantic embedding and, when width
        embeddings are configured, concatenates the span-width embedding.
        # Parameters
        sequence_tensor : `torch.FloatTensor`, required.
            A tensor of shape (batch_size, sequence_length, embedding_size)
            representing an embedded sequence of words.
        span_indices : `torch.LongTensor`, required.
            A tensor of shape `(batch_size, num_spans, 2)`, where the last
            dimension represents the inclusive start and end indices of the
            span to be extracted from the `sequence_tensor`.
        adj_matrix : `torch.FloatTensor`, optional (default = `None`).
            Adjacency matrix handed to the graph module; required when the
            extractor was built with `use_gcn=True`.
        sequence_mask : `torch.FloatTensor`, optional (default = `None`).
            A tensor of shape (batch_size, sequence_length) representing padded
            elements of the sequence.
        span_indices_mask : `torch.BoolTensor`, optional (default = `None`).
            A tensor of shape (batch_size, num_spans) representing the valid
            spans in the `indices` tensor. This mask is optional because
            sometimes it's easier to worry about masking after calling this
            function, rather than passing a mask directly.
        # Returns
        A tensor of shape `(batch_size, num_spans, embedded_span_size)`,
        where `embedded_span_size` depends on the way spans are represented.
        """
        # shape (batch_size, num_spans, reduced_dim)
        span_embeddings = self._embed_spans(sequence_tensor, span_indices, adj_matrix,
                                            sequence_mask, span_indices_mask)
        if self._span_width_embedding is not None:
            # width = end_index - start_index + 1 since `SpanField` use inclusive indices.
            # But here we do not add 1 because we often initiate the span width
            # embedding matrix with `num_width_embeddings = max_span_width`
            # shape (batch_size, num_spans)
            widths_minus_one = span_indices[..., 1] - span_indices[..., 0]

            if self._bucket_widths:
                widths_minus_one = bucket_values(
                    widths_minus_one,
                    num_total_buckets=self._num_width_embeddings)  # type: ignore

            # Embed the span widths and concatenate to the rest of the representations.
            span_width_embeddings = self._span_width_embedding(
                widths_minus_one)
            span_embeddings = torch.cat(
                [span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            # Here we are masking the spans which were originally passed in as padding.
            return span_embeddings * span_indices_mask.unsqueeze(-1)
        return span_embeddings

In [None]:
# биаффинный аттеншен

class BiaffineSpanExtractor(SpanExtractor):
    """
    Span extractor that augments every token with biaffine features before
    pooling the tokens of each span into a single vector.

    Steps performed by `_embed_spans`:

    1. The token embeddings go through a single-layer bidirectional LSTM. Its
       output is split into the first and the second `input_dim // 2` features
       and both halves are passed to a `Biaffine` module.
    2. The biaffine output (and the `AGGCN` output when `use_gcn` is set) is
       concatenated with the original token embeddings.
    3. A linear layer maps the concatenation to `reduced_dim + 1` features: the
       first `reduced_dim` are the token representation, the last one is a
       per-token attention logit.
    4. Every span is the attention-weighted sum of its token representations,
       the softmax being taken only over the tokens inside the span.

    `forward` additionally concatenates a span-width embedding when one is
    configured.

    # Parameters
    input_dim : `int`, required.
        The final dimension of the `sequence_tensor`.
    reduced_dim : `int`, required.
        Size of the per-token representation produced by the linear reducer
        (and therefore of the pooled span vector, before the width embedding).
    num_width_embeddings : `int`, optional (default = `None`).
        Specifies the number of buckets to use when representing
        span width features.
    span_width_embedding_dim : `int`, optional (default = `None`).
        The embedding size for the span_width features.
    bucket_widths : `bool`, optional (default = `False`).
        Whether to bucket the span widths into log-space buckets. If `False`,
        the raw span widths are used.
    use_gcn : `bool`, optional (default = `True`).
        Whether to also run the `AGGCN` graph module (it receives the adjacency
        matrix) and concatenate its output to the token features.

    # Raises
    ConfigurationError
        If only one of `num_width_embeddings` / `span_width_embedding_dim` is given.
    """

    def __init__(
            self,
            input_dim: int,
            reduced_dim: int,
            num_width_embeddings: Optional[int] = None,
            span_width_embedding_dim: Optional[int] = None,
            bucket_widths: bool = False,
            use_gcn: bool = True) -> None:
        super().__init__()

        self._input_dim = input_dim
        self._num_width_embeddings = num_width_embeddings
        self._bucket_widths = bucket_widths
        self._use_gcn = use_gcn
        self._span_width_embedding = None

        if num_width_embeddings is not None and span_width_embedding_dim is not None:
            self._span_width_embedding = nn.Embedding(
                num_embeddings=num_width_embeddings,
                embedding_dim=span_width_embedding_dim)
        elif num_width_embeddings is not None or span_width_embedding_dim is not None:
            # The two literals are concatenated, hence the trailing space.
            raise ConfigurationError(
                "To use a span width embedding representation, you must "
                "specify both num_width_embeddings and span_width_embedding_dim."
            )

        # Each LSTM direction gets half of the input size (768 / 2 for BERT-base),
        # so the bidirectional output is 2 * lstm_dim wide.
        self.lstm_dim = self._input_dim // 2
        # batch_first=True: tensors are (batch, seq, features).
        self.bilstm = nn.LSTM(self._input_dim, self.lstm_dim,
                              num_layers=1, bidirectional=True, batch_first=True)

        self.dep_vec_dim = 256
        self.biaffine = Biaffine(self.lstm_dim, self.dep_vec_dim)

        self.gcn_dim = 0
        if self._use_gcn:
            self.gcn_dim = 256
            self.graph_module = AGGCN(self._input_dim, self.gcn_dim,
                                      tree_prop=1,
                                      tree_dropout=0.2,
                                      aggcn_heads=4,
                                      aggcn_sublayer_first=2,
                                      aggcn_sublayer_second=4)

        # Input size is derived from the configured dimensions instead of the
        # former hard-coded 768 + 256, so encoders of other sizes work too.
        # The extra output feature (+1) is the per-token attention logit.
        concat_dim = self._input_dim + self.dep_vec_dim + self.gcn_dim
        self.dim_reducer = nn.Linear(concat_dim, reduced_dim + 1)

    def _embed_spans(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            adj_matrix: torch.FloatTensor = None,
            sequence_mask: torch.FloatTensor = None,
            span_indices_mask: torch.BoolTensor = None) -> torch.FloatTensor:
        """
        Build one vector per span by attention-pooling the enriched token features.

        `adj_matrix` and `sequence_mask` are only consumed by the graph module
        (i.e. when `use_gcn` is set); `span_indices_mask` is not used here, the
        padding spans are zeroed in `forward`.
        """
        output, _ = self.bilstm(sequence_tensor)
        # First / second half of the feature axis of the bidirectional output.
        h_forward = output[:, :, :self.lstm_dim]
        h_backward = output[:, :, self.lstm_dim:]

        dep_output = self.biaffine(h_forward, h_backward)
        if self._use_gcn:
            graph_output = self.graph_module(adj_matrix, sequence_tensor, sequence_mask)
            # (batch, seq_len, input_dim + dep_vec_dim + gcn_dim)
            concat_tensor = torch.cat((sequence_tensor, dep_output, graph_output), -1)
        else:
            # (batch, seq_len, input_dim + dep_vec_dim)
            concat_tensor = torch.cat((sequence_tensor, dep_output), -1)

        # (batch, seq_len, reduced_dim + 1)
        reduced_tensor = self.dim_reducer(concat_tensor)

        concat_output, span_mask = batched_span_select(reduced_tensor,
                                                       span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width, reduced_dim)
        span_embeddings = concat_output[:, :, :, :-1]
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = concat_output[:, :, :, -1]

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = masked_softmax(span_attention_logits,
                                                span_mask)

        # Weighted sum of the token vectors of each span with respect to the
        # normalised attention distribution. The span-width axis is summed out,
        # which is why the result has one dimension less.
        # Shape: (batch_size, num_spans, reduced_dim)
        attended_text_embeddings = weighted_sum(span_embeddings,
                                                span_attention_weights)
        return attended_text_embeddings

    def forward(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            adj_matrix: torch.FloatTensor = None,
            sequence_mask: torch.FloatTensor = None,
            span_indices_mask: torch.BoolTensor = None):
        """
        Extract the spans, compute their pooled embedding and concatenate it
        with the span-width embedding (when configured).

        # Parameters
        sequence_tensor : `torch.FloatTensor`, required.
            A tensor of shape (batch_size, sequence_length, embedding_size)
            representing an embedded sequence of words.
        span_indices : `torch.LongTensor`, required.
            A tensor of shape `(batch_size, num_spans, 2)`, where the last
            dimension represents the inclusive start and end indices of the
            span to be extracted from the `sequence_tensor`.
        adj_matrix : `torch.FloatTensor`, optional (default = `None`).
            Adjacency matrix forwarded to the graph module; only read when
            the extractor was built with `use_gcn=True`.
        sequence_mask : `torch.FloatTensor`, optional (default = `None`).
            A tensor of shape (batch_size, sequence_length) representing padded
            elements of the sequence.
        span_indices_mask : `torch.BoolTensor`, optional (default = `None`).
            A tensor of shape (batch_size, num_spans) representing the valid
            spans in the `indices` tensor. This mask is optional because
            sometimes it's easier to worry about masking after calling this
            function, rather than passing a mask directly.

        # Returns
        A tensor of shape `(batch_size, num_spans, embedded_span_size)`,
        where `embedded_span_size` depends on the way spans are represented.
        """
        # shape (batch_size, num_spans, reduced_dim)
        span_embeddings = self._embed_spans(sequence_tensor, span_indices, adj_matrix,
                                            sequence_mask, span_indices_mask)
        if self._span_width_embedding is not None:
            # width = end_index - start_index + 1 since the indices are inclusive.
            # The +1 is deliberately omitted so that the values index an embedding
            # matrix created with `num_width_embeddings = max_span_width`.
            # shape (batch_size, num_spans)
            widths_minus_one = span_indices[..., 1] - span_indices[..., 0]

            if self._bucket_widths:
                widths_minus_one = bucket_values(
                    widths_minus_one,
                    num_total_buckets=self._num_width_embeddings)  # type: ignore

            # Embed the span widths and concatenate to the rest of the representations.
            span_width_embeddings = self._span_width_embedding(
                widths_minus_one)
            span_embeddings = torch.cat(
                [span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            # Zero out the spans which were originally passed in as padding.
            return span_embeddings * span_indices_mask.unsqueeze(-1)
        return span_embeddings

# NER Classifier

In [None]:
class NERTagger(nn.Module):
    """
    Span classifier: scores every span against the entity labels.

    A feed-forward layer (Linear -> ReLU -> Dropout) followed by a linear layer
    produces `num_labels - 1` scores per span. A constant "dummy" score is
    prepended at index 0, so the output has `num_labels` scores per span.

    # Parameters
    input_dim : `int`
        Size of the span embeddings.
    num_labels : `int`
        Total number of labels, including the one at index 0 that is
        represented by the dummy score.
    ff_dropout : `float`, optional (default = 0.4)
        Dropout probability applied after the hidden layer.
    """

    def __init__(self,
                 input_dim: int,
                 num_labels: int,
                 ff_dropout: float = 0.4) -> None:
        super(NERTagger, self).__init__()

        self._num_labels = num_labels
        feed_forward = nn.Sequential(nn.Linear(input_dim, 256),
                                     nn.ReLU(),
                                     nn.Dropout(ff_dropout))
        self._ner_scorer = torch.nn.Sequential(
            TimeDistributed(feed_forward),
            TimeDistributed(nn.Linear(256, self._num_labels - 1)))

    def forward(self,
                spans: torch.LongTensor,
                span_mask: torch.BoolTensor,
                span_embeddings: torch.FloatTensor,
                ner_labels: torch.LongTensor = None,
                previous_step_output: Optional[Dict[str, Any]] = None) -> torch.Tensor:
        """
        Score the spans.

        # Parameters
        spans : (batch_size, num_spans, 2) start/end indices of every span.
        span_mask : (batch_size, num_spans) mask of the spans to be scored;
            masked-out spans get -1e20 for every non-dummy label.
        span_embeddings : (batch_size, num_spans, input_dim) span vectors.
        ner_labels : unused, kept for interface compatibility.
        previous_step_output : optional dict, only consulted outside training.
            "predicted_span" - (batch_size, num_spans) 0/1 tensor: spans marked 1
            get a dummy score of -1e20, the others +1e20.
            "predicted_seq_span" - per batch row, a list of (start, end) pairs:
            unmasked spans found in the list get a dummy score of -1e20,
            unmasked spans not found get +1e20.

        # Returns
        Tensor of shape (batch_size, num_spans, num_labels); index 0 of the last
        axis is the dummy score.
        """
        # Shape: (batch_size, num_spans, num_labels - 1)
        ner_scores = self._ner_scorer(span_embeddings)
        # Give large negative scores to masked-out elements.
        mask = span_mask.unsqueeze(-1)
        ner_scores = replace_masked_values(ner_scores, mask, -1e20)

        dummy_scores = ner_scores.new_zeros(ner_scores.size(0), ner_scores.size(1), 1)

        use_previous = previous_step_output is not None and not self.training
        if use_previous and "predicted_span" in previous_step_output:
            predicted_span = previous_step_output["predicted_span"]
            dummy_scores.masked_fill_(predicted_span.bool().unsqueeze(-1), -1e20)
            dummy_scores.masked_fill_((1 - predicted_span).bool().unsqueeze(-1), 1e20)

        ner_scores = torch.cat((dummy_scores, ner_scores), -1)

        if use_previous and "predicted_seq_span" in previous_step_output:
            for row_idx, all_spans in enumerate(spans):
                # Cast the predicted spans to the dtype of `spans` (as before), then
                # index them in a set: one O(1) lookup per span instead of a nested
                # Python loop comparing tensor elements one by one.
                pred_spans = all_spans.new_tensor(
                    previous_step_output["predicted_seq_span"][row_idx])
                predicted_pairs = {(pred[0], pred[1]) for pred in pred_spans.tolist()}
                row_mask = span_mask[row_idx].tolist()

                for col_idx, span in enumerate(all_spans.tolist()):
                    if row_mask[col_idx] == 0:
                        continue
                    if (span[0], span[1]) in predicted_pairs:
                        # Found: let the label scores decide, dummy gets a big negative.
                        ner_scores[row_idx, col_idx, 0] = -1e20
                    else:
                        # Not found: force the dummy label with a big positive.
                        ner_scores[row_idx, col_idx, 0] = 1e20

        return ner_scores

# Main NNER Class

In [None]:
class NNERModel(nn.Module):
    """
    Nested NER model: BERT encoder -> span extractor -> pruner + span classifier.

    All spans `(start, end)` with `0 <= end - start < max_span_len` inside a
    window of `max_seq_len` tokens are enumerated once (see `triangle_mask`);
    every span is embedded by the selected extractor, scored by a 1-output
    linear "pruner" (entity / not entity) and classified by `NERTagger`.

    # Parameters
    num_labels : number of labels scored by the classifier.
    dropout_rate : dropout applied to the encoder output.
    max_seq_len : sequence length used to enumerate the spans; the inputs are
        expected to be padded/truncated to this length.
    max_span_len : maximum span width in tokens. `None`, non-positive values and
        values larger than `max_seq_len` fall back to 50.
    extractor_type : 'biaffine', 'selfbiaffine' or 'attention'.
    span_threshold : optional probability threshold for the pruner. With the
        default `None` the historical mask computation is kept (see the note in
        `forward`); with a float, a span is kept when
        `sigmoid(pruner_logit) > span_threshold`.

    # Raises
    NotImplementedError : for `extractor_type='maxpooling'` (no extractor exists).
    ValueError : for any other unknown `extractor_type`.
    """

    def __init__(self,
                 num_labels,
                 dropout_rate=0.4,
                 max_seq_len=284,
                 max_span_len=None,
                 extractor_type='biaffine',
                 span_threshold=None):
        super(NNERModel, self).__init__()

        # Fail early: previously these cases left `self.extractor` undefined and
        # only surfaced as an AttributeError on the first forward pass.
        if extractor_type == 'maxpooling':
            raise NotImplementedError("extractor_type 'maxpooling' is not implemented")
        if extractor_type not in ('biaffine', 'attention', 'selfbiaffine'):
            raise ValueError(f"Unknown extractor_type: {extractor_type!r}")

        self.max_span_len = max_span_len
        self.num_labels = num_labels
        self.extractor_type = extractor_type
        self.span_threshold = span_threshold

        # Normalise the maximum span width.
        if (self.max_span_len is None
                or self.max_span_len > max_seq_len
                or self.max_span_len <= 0):
            self.max_span_len = 50
        self.max_span_len = int(self.max_span_len)

        # Band of the upper triangle: entry (i, j) is True iff
        # 0 <= j - i < max_span_len, i.e. (i, j) is an allowed (start, end) pair.
        # The normalised `self.max_span_len` is used here; the raw argument
        # crashed `torch.triu` when it was left at its default `None`.
        ones = torch.ones(max_seq_len, max_seq_len)
        self.triangle_mask = (torch.triu(ones, diagonal=0)
                              - torch.triu(ones, diagonal=self.max_span_len)).bool()

        self.encoder = BertModel.from_pretrained('DeepPavlov/rubert-base-cased')
        encoder_hidden_state = self.encoder.config.hidden_size

        self.encoder_dropout = nn.Dropout(dropout_rate)

        reduced_dim = 512
        width_dim = 128
        if extractor_type == 'biaffine':
            self.extractor = BiaffineSpanExtractor(input_dim=encoder_hidden_state,
                                                   reduced_dim=reduced_dim,
                                                   num_width_embeddings=self.max_span_len,
                                                   span_width_embedding_dim=width_dim,
                                                   use_gcn=False)
        elif extractor_type == 'attention':
            # NOTE(review): this extractor is built with
            # reduced_dim=encoder_hidden_state while the pruner/classifier below
            # expect reduced_dim + width_dim inputs - confirm that the output
            # size of SelfAttentiveSpanExtractor matches.
            self.extractor = SelfAttentiveSpanExtractor(input_dim=encoder_hidden_state,
                                                        reduced_dim=encoder_hidden_state,
                                                        num_width_embeddings=self.max_span_len,
                                                        span_width_embedding_dim=width_dim,
                                                        )
        else:  # 'selfbiaffine'
            self.extractor = SelfBiaffineSpanExtractor(input_dim=encoder_hidden_state,
                                                       reduced_dim=reduced_dim,
                                                       num_width_embeddings=self.max_span_len,
                                                       span_width_embedding_dim=width_dim,
                                                       use_gcn=False)

        # One logit per span: entity vs. not an entity.
        self.pruner = nn.Linear(reduced_dim + width_dim, 1)

        self.classifier = NERTagger(reduced_dim + width_dim, self.num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        adj_matrix=None
        ):
        """
        Run the model on a batch.

        # Parameters
        input_ids : (batch_size, seq_len) token ids for the encoder.
        attention_mask : (batch_size, seq_len) mask, also forwarded to the
            extractor as `sequence_mask`.
        adj_matrix : adjacency matrix, forwarded to the biaffine extractors only.

        # Returns
        A tuple `(scores, span_type_logits, span_indices)`:
        the classifier output, the pruner logits reshaped to
        (batch_size, num_spans) and the (batch_size, num_spans, 2) span indices.
        """
        batch_size = input_ids.size(0)

        encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        embedded_text_input = self.encoder_dropout(
            F.leaky_relu(encoder_output.last_hidden_state))

        # The same span enumeration is shared by all rows of the batch; it is
        # placed on the device of the inputs rather than on a global device.
        span_indices = self.triangle_mask.nonzero().unsqueeze(0).expand(
            batch_size, -1, -1).to(input_ids.device)

        extractor_kwargs = dict(sequence_tensor=embedded_text_input,
                                span_indices=span_indices,
                                sequence_mask=attention_mask)
        if self.extractor_type in ('biaffine', 'selfbiaffine'):
            extractor_kwargs['adj_matrix'] = adj_matrix
        span_embeddings = self.extractor(**extractor_kwargs)

        # Shape: (batch_size, num_spans, 1)
        span_type_logits = self.pruner(span_embeddings)

        if self.span_threshold is None:
            # NOTE(review): the pruner has a single output, so argmax over that
            # axis is always 0 and this mask is all-False for every span. The
            # computation is kept as the default so that existing checkpoints and
            # results do not change silently; pass `span_threshold` to obtain a
            # mask that actually depends on the pruner.
            span_mask = torch.argmax(span_type_logits, dim=2).bool()
        else:
            span_mask = torch.sigmoid(span_type_logits.squeeze(-1)) > self.span_threshold

        output = self.classifier(spans=span_indices,
                                 span_mask=span_mask,
                                 span_embeddings=span_embeddings)

        return (output,
                span_type_logits.reshape(span_type_logits.shape[0], span_type_logits.shape[1]),
                span_indices)

In [None]:
# Label vocabularies in both directions (label -> id and id -> label).
# NOTE: pickle executes arbitrary code on load - only use files you created yourself.
with open('/home/data_v3/labels2ids.pkl', 'rb') as pkl, \
        open('/home/data_v3/ids2labels.pkl', 'rb') as pkl_two:
    labels2ids = pickle.load(pkl)
    ids2labels = pickle.load(pkl_two)

# Training

In [None]:
# Training sentences: ';'-separated CSV, the text lives in the `Contents` column.
ner_ds_path = '/home/data_v3/train_texts.csv'
train_texts = pd.read_csv(ner_ds_path, sep=';')

# Row positions whose text consists of whitespace only.
to_del = [
    row_idx
    for row_idx, text in enumerate(train_texts['Contents'])
    if re.match(r'^\s+$', text)
]

# Drop those rows and renumber the remaining ones from 0; the column that
# `reset_index` creates from the old index is discarded right away.
train_texts = train_texts.drop(to_del).reset_index().drop('index', axis=1)

# Passed to the Trainer as `data_indices` together with the span labels.
train_inds = get_ind_sequence(train_texts)

# Training labels.
ner_ds_path = '/home/data_v3/train_spans.csv'
train_spans = pd.read_csv(ner_ds_path, sep=';')

In [None]:
# Torch dataset over the training sentences. max_len=284 is the same value that
# is passed to NNERModel as max_seq_len further down; keep the two in sync.
train_texts_ds = dataset(train_texts, max_len=284)

## Validation data

In [None]:
# Validation (dev) sentences: ';'-separated CSV, the text lives in `Contents`.
ner_ds_path = '/home/data_v3/dev_texts.csv'
val_texts = pd.read_csv(ner_ds_path, sep=';')

# Row positions whose text consists of whitespace only.
to_del = [
    row_idx
    for row_idx, text in enumerate(val_texts['Contents'])
    if re.match(r'^\s+$', text)
]

# Drop those rows and renumber the remaining ones from 0; the column that
# `reset_index` creates from the old index is discarded right away.
val_texts = val_texts.drop(to_del).reset_index().drop('index', axis=1)

# Passed to the Trainer as `data_indices` together with the span labels.
val_inds = get_ind_sequence(val_texts)

# Validation labels.
ner_ds_path = '/home/data_v3/dev_spans.csv'
val_spans = pd.read_csv(ner_ds_path, sep=';')

In [None]:
# Torch dataset over the validation sentences. max_len=284 is the same value that
# is passed to NNERModel as max_seq_len further down; keep the two in sync.
val_texts_ds = dataset(val_texts, max_len=284)

## Test data

In [None]:

# Test sentences: ';'-separated CSV, the text lives in the `Contents` column.
ner_ds_path = '/home/data_v3/test_texts.csv'
test_texts = pd.read_csv(ner_ds_path, sep=';')

# Row positions whose text consists of whitespace only.
to_del = [
    row_idx
    for row_idx, text in enumerate(test_texts['Contents'])
    if re.match(r'^\s+$', text)
]

# Drop those rows and renumber the remaining ones from 0; the column that
# `reset_index` creates from the old index is discarded right away.
test_texts = test_texts.drop(to_del).reset_index().drop('index', axis=1)

# Passed to the Trainer as `data_indices` together with the span labels.
test_inds = get_ind_sequence(test_texts)

# Test labels.
ner_ds_path = '/home/data_v3/test_spans.csv'
test_spans = pd.read_csv(ner_ds_path, sep=';')

In [None]:
# Torch dataset over the test sentences. max_len=284 is the same value that
# is passed to NNERModel as max_seq_len further down; keep the two in sync.
test_texts_ds = dataset(test_texts, max_len=284)

In [None]:
# One sentence per step for every split.
TRAIN_BATCH_SIZE = 1
VALID_BATCH_SIZE = 1

# Only the training split is shuffled; validation and test keep the file order.
# num_workers=0 loads the data in the main process.
train_params = dict(batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=0)
val_params = dict(batch_size=VALID_BATCH_SIZE, shuffle=False, num_workers=0)
test_params = dict(batch_size=1, shuffle=False, num_workers=0)

training_loader = DataLoader(train_texts_ds, **train_params)
validation_loader = DataLoader(val_texts_ds, **val_params)
testing_loader = DataLoader(test_texts_ds, **test_params)

In [None]:
# Total epoch budget: it bounds the training loop (resumed runs only execute the
# remaining epochs) and defines the horizon of the learning-rate schedule.
EPOCHS = 100
# Peak learning rate handed to the Adam optimizer.
LEARNING_RATE = 1e-05

In [None]:
# Build the model: spans of up to 26 tokens inside a 284-token window
# (284 matches the max_len of the datasets above), then move it to `device`.
model = NNERModel(num_labels=len(ids2labels), max_span_len=26, max_seq_len=284,
                  extractor_type='biaffine') # alternative extractor: 'attention'
_ = model.to(device)

Some weights of the model checkpoint at DeepPavlov/rubert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# The scheduler is stepped once per training batch, so its horizon is
# (batches per epoch) x (epochs); the first 5% of the steps are warm-up.
total_steps = EPOCHS * len(training_loader)
warmup_steps = int(0.05 * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

In [None]:
# Experiment tag: prefix of the checkpoint file that is loaded on resume and
# written whenever the score improves.
safe_prefix = 'testing_version_num_35'

In [None]:
from_start = False # set to True to train the model from scratch; False resumes from the saved checkpoint

In [None]:
# Resume from the last saved checkpoint, or initialise the bookkeeping for a
# fresh run. `resume_epoch` is the number of epochs already completed and
# `last_best` the best score seen so far (the threshold for saving again).
if not from_start:
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU can still be loaded on a CPU-only machine.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    checkpoint = torch.load(f'/home/data_v3/{safe_prefix}_checkpoint.pth.tar',
                            map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    resume_epoch = checkpoint['epoch']
    last_best = checkpoint['f_score']
else:
    last_best = 0
    resume_epoch = 0

In [None]:
class Trainer(nn.Module):
    """
    Runs one pass (epoch) over a data loader, either training or evaluating.

    In 'train' mode the model is optimised with the sum of a circle loss over
    the span label scores and a binary cross-entropy loss over the pruner
    logits; the optimizer and the scheduler are stepped after every batch.
    In any other mode the model is put in eval mode and run without gradients.
    In both cases predictions and gold labels are collected and handed to the
    `Evaluator`.

    # Parameters
    model : the network; called with `input_ids`, `attention_mask`, `adj_matrix`.
    optimizer, scheduler : stepped once per training batch.
    ids2labels : id -> label mapping (its length is the number of classes).
    labels2ids : label -> id mapping, forwarded to `get_batch_labels`.
    """

    def __init__(self, model, optimizer, scheduler, ids2labels, labels2ids):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.circle_loss = CircleLoss(m=0.25, gamma=64)
        # Not part of the current objective; kept for cross-entropy experiments.
        self.entropy = nn.CrossEntropyLoss()

        self.num_classes = len(ids2labels)
        self.labels2ids = labels2ids

        self.evaluator = Evaluator(ids2labels)

    def _training_loss(self, predicted, ent_type, class_labels, type_labels):
        """Circle loss on the label scores plus BCE on the pruner logits, over active spans only."""
        # Flatten batch and span axes.
        predicted = predicted.view(-1, self.num_classes)
        class_labels = class_labels.view(-1)
        ent_type = ent_type.view(-1)
        type_labels = type_labels.view(-1)

        # Spans labelled -100 are excluded from both losses.
        active = class_labels != -100
        class_labels = class_labels[active]
        predicted = predicted[active]
        ent_type = ent_type[active]
        type_labels = type_labels[active]

        # Circle loss over the L2-normalised score vectors.
        norm_preds = nn.functional.normalize(predicted)
        inp_sp, inp_sn = convert_label_to_similarity(norm_preds, class_labels)
        ccl_loss = self.circle_loss(inp_sp, inp_sn)

        # Entity / not-entity loss for the pruner.
        # NOTE(review): BCEWithLogitsLoss needs floating-point targets - presumably
        # `get_batch_labels` already returns `type_labels` as floats; confirm.
        bcl_loss = self.bce_loss(ent_type, type_labels)

        # Out-of-place sum: the circle-loss tensor is not modified in place.
        return ccl_loss + bcl_loss

    def forward(self, data_loader, labels_dataset, data_indices, mode='train'):
        """
        Process every batch of `data_loader` once.

        # Parameters
        data_loader : yields dicts with the keys 'input_ids', 'mask', 'adj'
            and 'address'.
        labels_dataset, data_indices : forwarded to `get_batch_labels` to build
            the gold labels of each batch.
        mode : 'train' trains the model; any other value only evaluates. The
            value is also written to the log.

        # Returns
        The value returned by `Evaluator.evaluate` for the collected gold labels
        and predictions (the callers treat it as the F-score).
        """
        self.predictions = []
        self.true_labels = []
        epoch_loss = 0

        is_training = mode == 'train'
        if is_training:
            self.model.train()
            grad_regulator = torch.enable_grad
        else:
            self.model.eval()
            grad_regulator = torch.no_grad

        with grad_regulator():
            # Wrapping the loader itself lets tqdm read its length and show the total.
            for batch in tqdm(data_loader):
                if is_training:
                    self.optimizer.zero_grad()

                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['mask'].to(device)
                adj = batch['adj'].to(device)

                predicted, ent_type, span_indices = self.model(input_ids=input_ids,
                                                               attention_mask=attention_mask,
                                                               adj_matrix=adj)

                # Gold labels for the batch; the span enumeration is identical for
                # every row, so the first row is enough.
                class_labels, type_labels = get_batch_labels(batch['address'],
                                                             labels_dataset,
                                                             data_indices,
                                                             span_indices[0],
                                                             self.labels2ids)
                class_labels = class_labels.to(device)
                type_labels = type_labels.to(device)

                self.true_labels += class_labels.tolist()
                # Highest-scoring label per span, used for the evaluation.
                self.predictions += predicted.max(2).indices.tolist()

                if is_training:
                    loss = self._training_loss(predicted, ent_type,
                                               class_labels, type_labels)
                    epoch_loss += loss.item()

                    loss.backward()
                    self.optimizer.step()
                    self.scheduler.step()

            logger.info(f'RESULTS FOR MODE {mode.upper()}')
            epoch_f_score = self.evaluator.evaluate(self.true_labels, self.predictions)
            if is_training:
                epoch_loss = epoch_loss / len(data_loader)
                logger.info(f'Loss per epoch: {epoch_loss}')

        return epoch_f_score

In [None]:
# Bundle the model with its optimizer, scheduler and label vocabularies; the
# resulting object is called once per split and epoch in the training loop.
nner_trainer = Trainer(model=model,
                       optimizer=optimizer,
                       scheduler=scheduler,
                       ids2labels=ids2labels,
                       labels2ids=labels2ids
                       )

In [None]:
# Main loop. Epochs are numbered from 1 in the log and in the checkpoint; a
# resumed run only executes the epochs that are still missing.
for epoch in range(EPOCHS - resume_epoch):
    current_epoch = resume_epoch + epoch + 1
    logger.info(f'EPOCH {current_epoch}/{EPOCHS}')
    _ = nner_trainer(training_loader, train_spans, train_inds, mode='train')
    logger.info('\n')
    f_score = nner_trainer(validation_loader, val_spans, val_inds, mode='dev')
    logger.info('\n')
    # The test pass is for reporting only. Its score used to overwrite the dev
    # score, so checkpoints were selected on the test set; selection now uses
    # the dev score.
    # NOTE: a `last_best` restored from an older checkpoint was measured on the
    # test set - reset it if the two scores are not comparable.
    test_f_score = nner_trainer(testing_loader, test_spans, test_inds, mode='test')

    if f_score > last_best:
        last_best = f_score
        check_path = f'/home/data_v3/{safe_prefix}_checkpoint.pth.tar'
        torch.save({'epoch': current_epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'f_score': last_best}, check_path)
    logger.info('\n\n')