In [43]:
import pandas as pd
import spacy
from sklearn.model_selection import train_test_split
from spacy.tokens import DocBin

In [44]:
# Load the prepared dataset; later cells read its 'category' and
# 'tokenized_str' columns.
df = pd.read_csv(filepath_or_buffer='../../../raw_data/data_prep.csv')

In [45]:
# Pipeline used to turn raw strings into Doc objects. In the cells shown here
# it is only used through `nlp.make_doc` inside `create_docbin`.
nlp = spacy.load(name='ru_core_news_lg')

In [46]:
# Row count per original category label, most frequent first
# (sort/ascending are spelled out but equal the pandas defaults).
df['category'].value_counts(sort=True, ascending=False)

A      386
BBB    270
AA     199
AAA    153
BB     136
B       44
C       12
Name: category, dtype: int64

In [47]:
# Original category labels grouped into three coarse classes.
# `set_cat_sm` tests membership in cat1 and cat2 only; any other value
# (cat3 included) falls through to class '3'.
# NOTE(review): the grouping does not follow label order ('B' sits with 'A',
# 'C' with 'BBB'/'BB') -- confirm this is intentional.
cat1, cat2, cat3 = (
    ['AAA', 'AA'],
    ['B', 'A'],
    ['C', 'BBB', 'BB'],
)

In [48]:
def set_cat_sm(text):
    """Map an original category label to its coarse class id.

    Args:
        text: Category label to classify.

    Returns:
        '1' if the label is in ``cat1``, '2' if it is in ``cat2``,
        and '3' for every other value.
    """
    # Checked in order: cat1 first, then cat2.
    for class_id, members in (('1', cat1), ('2', cat2)):
        if text in members:
            return class_id
    # Fallback: anything not matched above is treated as class '3'.
    return '3'

In [49]:
def create_docbin(data):
    """Build a spaCy ``DocBin`` of labelled documents from a DataFrame.

    Each row becomes one ``Doc`` created with the module-level ``nlp``
    pipeline's ``make_doc``. Its ``cats`` dict gets a 0 for each of the
    labels '1', '2', '3' and a 1 for the row's own label.

    Args:
        data: DataFrame with a ``tokenized_str`` column (converted with
            ``str`` before tokenizing) and a ``category_sm`` column holding
            the label to mark as positive.

    Returns:
        A ``DocBin`` with one document per row, in row order.
    """
    db = DocBin()
    categories = ['1', '2', '3']
    # Walk the two columns positionally. Indexing with data[col][i] is a
    # label lookup, so it only works when the index is exactly 0..n-1; on a
    # frame coming straight out of train_test_split it raises KeyError.
    for text, label in zip(data["tokenized_str"], data["category_sm"]):
        doc = nlp.make_doc(str(text))
        doc.cats = {category: 0 for category in categories}
        doc.cats[label] = 1
        db.add(doc)
    return db

In [50]:
# Derive the coarse class id ('1'/'2'/'3') for every row from its category.
df['category_sm'] = df["category"].map(set_cat_sm)

In [51]:
# Row count per coarse class, to check how balanced the three classes are.
df.loc[:, 'category_sm'].value_counts()

2    430
3    418
1    352
Name: category_sm, dtype: int64

In [52]:
# Hold out 10% of the rows as the dev set; the fixed seed keeps the split
# reproducible. The split is not stratified by class.
train, dev = train_test_split(
    df,
    test_size=0.1,
    random_state=42,
    shuffle=True,  # library default, spelled out
)

In [53]:
# Give both splits a fresh 0..n-1 index and drop the shuffled original one.
train, dev = (split.reset_index(drop=True) for split in (train, dev))

In [54]:
# Serialize each split to the binary files passed to `spacy train`.
create_docbin(data=dev).to_disk("dev.spacy")
create_docbin(data=train).to_disk("train.spacy")

In [55]:
# Run the spaCy training CLI with ../sym/config.cfg, writing the model to
# ./sym_model and overriding the config's train/dev paths with the files
# written above.
# NOTE(review): the recorded log shows a `textcat_multilabel` pipeline, while
# create_docbin marks exactly one positive label per document -- confirm that
# multilabel (rather than exclusive `textcat`) is intended.
! python -m spacy train ../sym/config.cfg --output ./sym_model --paths.train train.spacy --paths.dev dev.spacy

[38;5;4mℹ Saving to output directory: sym_model[0m
[38;5;4mℹ Using CPU[0m
[1m
[38;5;2m✔ Initialized pipeline[0m
[1m
[38;5;4mℹ Pipeline: ['textcat_multilabel'][0m
[38;5;4mℹ Initial learn rate: 0.001[0m
E    #       LOSS TEXTC...  CATS_SCORE  SCORE 
---  ------  -------------  ----------  ------
  0       0           0.25       54.75    0.55
  0     200          50.47       76.40    0.76
  0     400          38.08       80.71    0.81
  0     600          32.49       84.06    0.84
  0     800          23.85       86.75    0.87
  0    1000          27.58       89.58    0.90
  1    1200          12.63       87.72    0.88
  1    1400          11.88       88.84    0.89
  1    1600          13.55       90.35    0.90
  1    1800           8.65       90.13    0.90
  1    2000           8.70       91.29    0.91
  2    2200           8.45       90.17    0.90
  2    2400           3.21       91.33    0.91
  2    2600           5.16       91.62    0.92
  2    2800           4.89       90