In [1]:
%load_ext autoreload
%autoreload 2

In [33]:
from ast import literal_eval
import os
from pathlib import Path
import pickle
import random

import augmenty
import pandas as pd
from sklearn.model_selection import train_test_split
import spacy
from spacy.tokens import DocBin
from tqdm import tqdm

In [69]:
# Directory the serialized spaCy DocBin files (train/valid/test) are written to.
spacypath = Path("../../data/ner/custom")

# Pickled entity values; loaded further down and used as the replacement pool
# for the entity-swap augmentation.
entities_for_multilabel = Path("../../data/ner/custom/entities_values_multilabel.pickle")

# Excel sheets holding the already separated train / valid / test data.
train_filepath = Path("../../data/ner/baseline/train.xlsx")
valid_filepath = Path("../../data/ner/baseline/valid.xlsx")
test_filepath = Path("../../data/ner/baseline/test.xlsx")

# Augmenty 

In [58]:
# Load Augmenters

# Keystroke-error augmenter configured for the "ru.v1" keyboard with level=0.1.
keystroke_error_augmenter = augmenty.load("keystroke_error.v1", level=0.1, keyboard="ru.v1")

# Every digit maps to the nine *other* digits. The candidate order is
# "1".."9" with the slot of the key itself taken over by "0"
# (for the key "0" that is simply "1".."9").
replace_dict = {
    digit: ["0" if other == digit else other for other in "123456789"]
    for digit in "0123456789"
}
# Character-replace augmenter that swaps digits using the table above, level=0.9.
char_replace_augmenter = augmenty.load("char_replace.v1", level=0.9, replace=replace_dict)

# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
with open(entities_for_multilabel, "rb") as f:
    entity_values = pickle.load(f)

# Prepare data and split it

In [59]:
# Read the three Excel sheets in train / valid / test order.
train, valid, test = (
    pd.read_excel(sheet_path, engine="openpyxl")
    for sheet_path in (train_filepath, valid_filepath, test_filepath)
)

In [60]:
# TODO: replace `create_data` (the original note was left unfinished; the intended replacement is not specified)

def create_data(df, textcol="text", labelcol="markup"):
    """Turn a DataFrame into a list of ``(text, entities)`` training pairs.

    Args:
        df: DataFrame with one example per row.
        textcol: Name of the column holding the raw text.
        labelcol: Name of the column holding the entity markup as a Python
            literal string (parsed with ``ast.literal_eval``).

    Returns:
        A list of ``(text, [(start, end, label), ...])`` tuples in row order.
        Only the first three items of every parsed entity are kept.
    """
    examples = []

    for _, record in tqdm(df.iterrows(), total=len(df)):
        sentence = record[textcol]
        parsed_markup = literal_eval(record[labelcol])
        # Normalise every entity to a plain (start, end, label) triple.
        triples = [(item[0], item[1], item[2]) for item in parsed_markup]
        examples.append((sentence, triples))

    return examples

In [61]:
# Convert every split into (text, entities) pairs, in train / valid / test order.
train_data, valid_data, test_data = map(create_data, (train, valid, test))

100%|███████████████████████████████████████| 814/814 [00:00<00:00, 7015.00it/s]
100%|███████████████████████████████████████| 283/283 [00:00<00:00, 7642.41it/s]
100%|███████████████████████████████████████| 130/130 [00:00<00:00, 8100.84it/s]


In [62]:
# Number of examples per split (train, valid, test).
tuple(map(len, (train_data, valid_data, test_data)))

(814, 283, 130)

## Do Augmentations

In [63]:
# Blank Russian spaCy pipeline; passed to augmenty as the `nlp` argument below.
nlp = spacy.blank("ru")

def use_augmeny(text, augmenter, nlp):
    """Augment a single text and keep only the results of unchanged length.

    Args:
        text: The string to augment.
        augmenter: Augmenter forwarded to ``augmenty.texts``.
        nlp: spaCy pipeline forwarded to ``augmenty.texts``.

    Returns:
        A list with the augmented strings whose length equals ``len(text)``;
        results of any other length are discarded.
    """
    expected_length = len(text)
    return [
        candidate
        for candidate in augmenty.texts([text], augmenter=augmenter, nlp=nlp)
        if len(candidate) == expected_length
    ]


def replace_entity_with_same_length(text, annotations, entity_dict, n=3):
    """Build ``n`` variants of ``text`` with every annotated span swapped for a same-length value.

    Each replacement has exactly the length of the span it overwrites and is
    written at the span's own offsets, so the character offsets in
    ``annotations`` remain valid for every returned variant.

    Args:
        text: Source string.
        annotations: Entities indexable as ``(start, end, label, ...)`` where
            ``start``/``end`` are character offsets into ``text``.
        entity_dict: Mapping from label to a collection of candidate values.
        n: Number of variants to generate.

    Returns:
        A list of ``n`` strings, each as long as ``text``.

    Raises:
        KeyError: If an annotation's label is missing from ``entity_dict``.
    """
    fallback_digits = "123456789"

    # The candidate pool of a span does not depend on the variant being built,
    # so it is computed once per annotation instead of once per variant.
    replacement_plan = []
    for annotation in annotations:
        start, end, label = annotation[0], annotation[1], annotation[2]
        value_length = len(text[start:end])
        candidates = [
            candidate for candidate in entity_dict[label] if len(candidate) == value_length
        ]
        if value_length == 0:
            # Nothing to overwrite for an empty span.
            continue
        replacement_plan.append((start, end, value_length, candidates))

    variants = []
    for _ in range(n):
        variant = text
        for start, end, value_length, candidates in replacement_plan:
            if candidates:
                replacement = random.choice(candidates)
            else:
                # No known value of this length: fall back to random digits.
                # ``k`` must be a keyword; the second positional parameter of
                # random.choices is ``weights``.
                replacement = "".join(random.choices(fallback_digits, k=value_length))
            # Overwrite only the annotated span; other occurrences of the same
            # value elsewhere in the text are left untouched.
            variant = variant[:start] + replacement + variant[end:]
        variants.append(variant)

    return variants

In [64]:
def augment_data(data):
    """Expand every ``(text, annotations)`` pair with its augmented variants.

    For each input text the output contains, in this order: the original
    text, the digit-replacement variants, the keystroke-error variants and
    the entity-swap variants. Every variant is paired with the unchanged
    ``annotations`` of its source text.

    Args:
        data: Iterable of ``(text, annotations)`` pairs.

    Returns:
        A list of ``(text, annotations)`` pairs.
    """
    augmented_pairs = []

    for source_text, annotations in data:
        variants = [
            source_text,
            *use_augmeny(source_text, augmenter=char_replace_augmenter, nlp=nlp),
            *use_augmeny(source_text, augmenter=keystroke_error_augmenter, nlp=nlp),
            *replace_entity_with_same_length(
                source_text, annotations, entity_dict=entity_values
            ),
        ]
        augmented_pairs.extend((variant, annotations) for variant in variants)

    return augmented_pairs

In [65]:
# Augment every split. The names are rebound in place, so running this cell
# a second time augments the already augmented data again.
train_data, valid_data, test_data = map(
    augment_data, (train_data, valid_data, test_data)
)

In [66]:
# Number of examples per split after augmentation (train, valid, test).
tuple(map(len, (train_data, valid_data, test_data)))

(4884, 1698, 780)

In [68]:
# Blank Russian pipeline used by create_spacy_object below to tokenize the texts.
# The same pipeline is already created earlier in the notebook; this rebinds `nlp` to a fresh instance.
nlp = spacy.blank("ru")
        
        
def create_spacy_object(data, savepath, mode):
    """Convert ``(text, annotations)`` pairs into a spaCy ``DocBin`` and save it.

    Every string text is run through the notebook-level ``nlp`` pipeline. The
    annotations that ``doc.char_span`` can resolve are stored in the ``"txs"``
    span group of the resulting doc; unresolved ones are dropped and counted.
    Entries whose text is not a string are printed, skipped and reported in
    the final summary line.

    Args:
        data: Iterable of ``(text, annotations)`` pairs where ``annotations``
            is an iterable of ``(start, end, label)`` triples.
        savepath: Directory the file is written to.
        mode: File stem; the file is saved as ``{savepath}/{mode}.spacy``.

    Returns:
        The populated ``DocBin``.
    """
    db = DocBin()
    skipped_docs = 0
    dropped_spans = 0

    for text, annotations in data:
        if not isinstance(text, str):
            # Show the offending entry and count it so the summary is accurate.
            print(text, annotations)
            skipped_docs += 1
            continue

        doc = nlp(text)
        spans = []
        for start, end, label in annotations:
            span = doc.char_span(start, end, label=label)
            if span:
                spans.append(span)
            else:
                # char_span gave no usable span for these offsets.
                dropped_spans += 1

        doc.spans["txs"] = spans
        db.add(doc)

    dbpath = f"{savepath}/{mode}.spacy"
    db.to_disk(dbpath)
    if dropped_spans:
        print(f"{dropped_spans} annotations could not be converted to spans and were dropped")
    print(f"Saved to {dbpath}. {skipped_docs} Docs were not processed")
    return db

In [70]:
# Serialize every split to "<spacypath>/<mode>.spacy", in train / valid / test order.
db_train, db_valid, db_test = (
    create_spacy_object(data=split_data, savepath=spacypath, mode=split_name)
    for split_name, split_data in (
        ("train", train_data),
        ("valid", valid_data),
        ("test", test_data),
    )
)

Saved to ../../data/ner/custom/train.spacy. 0 Docs were not processed
Saved to ../../data/ner/custom/valid.spacy. 0 Docs were not processed
Saved to ../../data/ner/custom/test.spacy. 0 Docs were not processed


In [71]:
# Number of docs stored per DocBin (train, valid, test).
tuple(map(len, (db_train, db_valid, db_test)))

(4884, 1698, 780)

# Create config

In [None]:
%%bash

# Build the full training config: reads base_config.cfg and writes config.cfg.
python -m spacy init fill-config base_config.cfg config.cfg

# Run training (better in the command line)

In [None]:
%%bash

# Train with config.cfg on GPU 0 and write the results to ./output.
# NOTE(review): the notebook saves its DocBin files under ../../data/ner/custom,
# while this command reads ./spacy/train.spacy and ./spacy/valid.spacy - confirm
# that both locations contain the same files.
python -m spacy train config.cfg --gpu-id 0 --output ./output --paths.train ./spacy/train.spacy --paths.dev ./spacy/valid.spacy