In [1]:
import pandas as pd
import spacy
from sklearn.model_selection import train_test_split
from spacy.tokens import DocBin


In [2]:
df = pd.read_csv('../../../raw_data/data_prep.csv')

In [3]:
nlp = spacy.load('ru_core_news_lg')

In [4]:
def set_cat_sm(text):
    if 'A' in text:
        return 'A'
    elif 'B' in text:
        return 'B'
    return 'B'

In [5]:
def create_docbin(data):
    db = DocBin()
    categories = ['A', 'B']
    for i in range(data.shape[0]):    
        doc = nlp.make_doc(str(data["tokenized_str"][i]))
        doc.cats = {category: 0 for category in categories}
        doc.cats[data["category_sm"][i]] = 1
        db.add(doc)
    return db

In [6]:
def get_prection(text):
    doc = nlp(text)
    scores = doc.cats
    return max(scores, key=scores.get)

In [7]:
df['category_sm'] = df["category"].apply(set_cat_sm)

In [8]:
df['category_sm'].value_counts()

A    738
B    462
Name: category_sm, dtype: int64

In [9]:
train, dev = train_test_split(df, test_size=0.1, random_state=42)

In [10]:
train = train.reset_index(drop=True)
dev = dev.reset_index(drop=True)

In [11]:
create_docbin(dev).to_disk("../sym/dev.spacy")
create_docbin(train).to_disk("../sym/train.spacy")

In [12]:
! python -m spacy train ../sym/config.cfg --output ../sym/sym_model --paths.train ../sym/train.spacy --paths.dev ../sym/dev.spacy


[38;5;4mℹ Saving to output directory: ..\sym\sym_model[0m
[38;5;4mℹ Using CPU[0m
[1m
[38;5;2m✔ Initialized pipeline[0m
[1m
[38;5;4mℹ Pipeline: ['textcat'][0m
[38;5;4mℹ Initial learn rate: 0.001[0m
E    #       LOSS TEXTCAT  CATS_SCORE  SCORE 
---  ------  ------------  ----------  ------
  0       0          0.25       38.46    0.38
  0     200         58.28       73.86    0.74
  0     400         41.73       76.10    0.76
  0     600         25.54       81.71    0.82
  0     800         19.73       82.22    0.82
  0    1000         22.14       90.42    0.90
  1    1200         10.54       83.56    0.84
  1    1400          8.34       88.39    0.88
  1    1600          8.71       82.91    0.83
  1    1800          9.81       83.69    0.84
  1    2000          8.68       90.34    0.90
  2    2200          4.01       88.49    0.88
  2    2400          1.69       89.98    0.90
  2    2600          1.24       83.96    0.84
[38;5;2m✔ Saved pipeline to output directory[0m
..\sy