In [1]:
import pandas as pd
import spacy
from sklearn.model_selection import train_test_split
from spacy.tokens import DocBin


In [2]:
# Load the prepared dataset; the relative path is resolved against the kernel's
# current working directory.
df = pd.read_csv(filepath_or_buffer='../../../raw_data/data_prep.csv')

In [3]:
# Load the spaCy pipeline whose `make_doc` is used below to turn raw strings into Docs.
nlp = spacy.load(name='ru_core_news_lg')

In [4]:
# Distribution of the original labels before they are merged into coarse groups.
df.loc[:, 'category'].value_counts()

A      386
BBB    270
AA     199
AAA    153
BB     136
B       44
C       12
Name: category, dtype: int64

In [5]:
# Coarse label groups built from the original `category` values.
# set_cat_sm maps members of cat1 to '1' and members of cat2 to '2'; every other
# value (including the cat3 members) falls through to '3'.
cat1 = ["AAA", "AA"]
cat2 = ["B", "A"]
cat3 = ["C", "BBB", "BB"]  # listed for completeness; not consulted by set_cat_sm

In [6]:
def set_cat_sm(text):
    """Map an original category label to its coarse group id.

    Args:
        text: A value from the ``category`` column.

    Returns:
        ``'1'`` if ``text`` is in ``cat1``, ``'2'`` if it is in ``cat2``,
        otherwise ``'3'``. ``cat3`` is never consulted, so any value not
        present in the first two groups also yields ``'3'``.
    """
    # Groups are checked in order; the first group containing the label wins.
    for group_id, members in (('1', cat1), ('2', cat2)):
        if text in members:
            return group_id
    return '3'

In [7]:
def create_docbin(data):
    """Build a spaCy ``DocBin`` of text-classification examples from ``data``.

    Each row becomes one ``Doc`` created by the module-level ``nlp.make_doc``
    from the row's ``tokenized_str`` value (coerced with ``str``). The doc's
    ``cats`` dict has every label in ``['1', '2', '3']`` set to 0, and then the
    row's ``category_sm`` label set to 1.

    Rows are read positionally, so the result does not depend on the frame's
    index; a frame coming straight out of ``train_test_split`` (with a
    shuffled, non-contiguous index) is handled the same as a re-indexed one.

    Args:
        data: DataFrame with ``tokenized_str`` and ``category_sm`` columns.

    Returns:
        A ``DocBin`` containing one annotated doc per row, in row order.
    """
    categories = ['1', '2', '3']
    db = DocBin()
    # zip walks both columns in row order; the previous `data[col][i]` form was
    # a label lookup that raised KeyError (or read the wrong row) unless the
    # index happened to be exactly 0..n-1.
    for text, label in zip(data["tokenized_str"], data["category_sm"]):
        doc = nlp.make_doc(str(text))
        doc.cats = {category: 0 for category in categories}
        doc.cats[label] = 1
        db.add(doc)
    return db

In [8]:
# Derive the coarse group id ('1'/'2'/'3') for every row from its original label.
df['category_sm'] = df['category'].map(set_cat_sm)

In [9]:
# Check how many rows landed in each coarse group.
df.loc[:, 'category_sm'].value_counts()

2    430
3    418
1    352
Name: category_sm, dtype: int64

In [10]:
# Hold out 20% of the rows as the dev set; the fixed seed makes the split repeatable.
train, dev = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
)

In [11]:
# Discard the shuffled original index so each split is numbered 0..n-1.
train, dev = (subset.reset_index(drop=True) for subset in (train, dev))

In [12]:
# Serialize each split to the .spacy file that the training command reads.
for split_frame, out_path in ((dev, "dev.spacy"), (train, "train.spacy")):
    create_docbin(split_frame).to_disk(out_path)

In [13]:
# Train via the spaCy CLI using ../sym/config.cfg, reading train.spacy / dev.spacy
# from the working directory and saving the result under ./sym_model.
# NOTE(review): the training log reports pipeline ['textcat_multilabel'], while
# create_docbin marks exactly one category per doc — confirm that the multilabel
# component (rather than an exclusive-class one) is what the config should use.
! python -m spacy train ../sym/config.cfg --output ./sym_model --paths.train train.spacy --paths.dev dev.spacy


[38;5;4mℹ Saving to output directory: sym_model[0m
[38;5;4mℹ Using CPU[0m
[1m
[38;5;2m✔ Initialized pipeline[0m
[1m
[38;5;4mℹ Pipeline: ['textcat_multilabel'][0m
[38;5;4mℹ Initial learn rate: 0.001[0m
E    #       LOSS TEXTC...  CATS_SCORE  SCORE 
---  ------  -------------  ----------  ------
  0       0           0.25       50.44    0.50
  0     200          42.03       78.71    0.79
  0     400          33.66       82.23    0.82
  0     600          36.73       84.07    0.84
  0     800          25.79       87.02    0.87
  1    1000          27.29       88.14    0.88
  1    1200           8.14       88.33    0.88
  1    1400          11.56       89.17    0.89
  1    1600          12.34       90.26    0.90
  1    1800          12.54       90.66    0.91
  2    2000           9.58       91.39    0.91
  2    2200           5.36       90.26    0.90
  2    2400           3.64       91.10    0.91
  2    2600           4.02       91.30    0.91
  2    2800           4.57       91