In [17]:
from transformers import BertModel, BertTokenizer, AutoModelForTokenClassification, AutoTokenizer, DistilBertModel, DistilBertTokenizer
import pandas as pd
import pytorch_lightning as pl
from seqeval.metrics.sequence_labeling import get_entities
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torchcrf import CRF

torch.set_float32_matmul_precision("high")

In [27]:
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

In [2]:
# !pip install seqeval==1.2.2
# !pip install allennlp==2.10.1
# !pip install pytorch-crf

# Utils

Полезные функции для работы с BIO-тегами

In [18]:
BERT_MODEL_NAME = "distilbert-base-multilingual-cased"
tokenizer = DistilBertTokenizer.from_pretrained(BERT_MODEL_NAME)

In [4]:
def apply_bio_tagging(row):
    """
    По токенам чека и разметке (то есть выделенным товарам и брендам) строим BIO-теги
    """
    tokens = row["tokens"]["input_ids"]
    good = row["good"].split(',')[0].split()
    brand = row["brand"].split(',')[0].split()
    tags = ['O'] * len(tokens)
    for i, token in enumerate(tokens):
        if len(good) > 0 and tokens[i:i + len(good)] == good:
            tags[i] = "B-GOOD"
            for j in range(i + 1, i + len(good)):
                tags[j] = "I-GOOD"
        if len(brand) > 0 and tokens[i:i + len(brand)] == brand:
            tags[i] = "B-BRAND"
            for j in range(i + 1, i + len(brand)):
                tags[j] = "I-BRAND"
    return tags

Прямое и обратное преобразование тегов в индексы

In [5]:
index_to_tag = ["O", "B-GOOD", "I-GOOD", "B-BRAND", "I-BRAND", "PAD"]
tag_to_index = {tag: index for index, tag in enumerate(index_to_tag)}

# Datamodule

Подготовим данные для модели. Для этого определим наследника `torch.nn.utils.Dataset` - `ReceiptsDataset`

In [6]:
class ReceiptsDataset(Dataset):
    def __init__(self, df):
        super().__init__()
        self.is_predict = "tags" not in df.columns
        self.data = df[["tokens", "good", "brand", "tags"]] if not self.is_predict else df[["tokens", "id"]]
        self.data = self.data.values

    def __getitem__(self, index):
        identifier = 0 if not self.is_predict else self.data[index][1]
        tokens = self.data[index][0]
        goods = self.data[index][1].split(',') if not self.is_predict else list()
        brands = self.data[index][2].split(',') if not self.is_predict else list()
        tags = self.data[index][3] if not self.is_predict else ["O"] * len(tokens["input_ids"])
        target = [tag_to_index[tag] for tag in tags]
        return identifier, tokens, goods, brands, target


    def __len__(self):
        return len(self.data)

Для объединения примеров в батчи нужна специальная `collate_fn`, в которой происходит паддинг

In [7]:
def collate_fn(batch):
    ids, tokens_sequence, goods, brands, targets = list(zip(*batch))

    # Соберите все input_ids и attention_masks в списки
    input_ids = [item['input_ids'] for item in tokens_sequence]
    attention_masks = [item['attention_mask'] for item in tokens_sequence]

    # Добавьте паддинг к input_ids, attention_masks и targets
    input_ids = pad_sequence([torch.LongTensor(ids) for ids in input_ids], batch_first=True)
    attention_masks = pad_sequence([torch.LongTensor(mask) for mask in attention_masks], batch_first=True)
    targets = pad_sequence([torch.LongTensor(target) for target in targets], batch_first=True, padding_value=tag_to_index["PAD"])

    return ids, input_ids, attention_masks, goods, brands, targets

Используем LightningDataModule для задания пайплайна

1. prepare_data
    1. Токенизируем текст
    2. Выделяем BIO-теги в размеченной части
2. setup
    1. Разделяем размеченную выборку на обучающую и валидационную
    2. Создаем `ReceiptsDataset` под каждую выборку

In [8]:
class ReceiptsDataModule(pl.LightningDataModule):
    def __init__(self,
                 train_dataset_path,
                 test_dataset_path,
                 val_split_size,
                 batch_size,
                 num_workers):
        super().__init__()
        self.train_dataset_path = train_dataset_path
        self.test_dataset_path = test_dataset_path
        self.val_split_size = val_split_size
        self.batch_size = batch_size
        self.num_workers = num_workers


    def prepare_data(self):
        self.train_df = pd.read_csv(self.train_dataset_path).fillna("")
        self.test_df = pd.read_csv(self.test_dataset_path)

        self.train_df["tokens"] = self.train_df["name"].apply(lambda x: tokenizer.encode_plus(x, truncation=True, padding=False))
        self.test_df["tokens"] = self.test_df["name"].apply(lambda x: tokenizer.encode_plus(x, truncation=True, padding=False))

        self.train_df["tags"] = self.train_df.apply(apply_bio_tagging, axis=1)

    def setup(self, stage: str):
        self.train_df, self.val_df = train_test_split(self.train_df, test_size=self.val_split_size)

        self.train_dataset = ReceiptsDataset(self.train_df)
        self.val_dataset = ReceiptsDataset(self.val_df)
        self.predict_dataset = ReceiptsDataset(self.test_df)
        
    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          collate_fn=collate_fn)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          collate_fn=collate_fn)

    def predict_dataloader(self):
        return torch.utils.data.DataLoader(self.predict_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           collate_fn=collate_fn)

In [9]:
TRAIN_DATASET_PATH = "data/train_supervised_dataset.csv"
TEST_DATASET_PATH = "data/test_dataset.csv"
VAL_SPLIT_SIZE = 0.1
BATCH_SIZE = 512
NUM_WORKERS = 4

In [10]:
dm = ReceiptsDataModule(
    TRAIN_DATASET_PATH,
    TEST_DATASET_PATH,
    VAL_SPLIT_SIZE,
    BATCH_SIZE,
    NUM_WORKERS
)

# Model

Сначала определим метрику `F1` для задачи NER

In [11]:
class F1Score:
    def __init__(self):
        self.tp = 0
        self.fp = 0
        self.fn = 0

    def update(self, pred, target):
        pred = frozenset(x for x in pred)
        target = frozenset(x for x in target)
        self.tp += len(pred & target)
        self.fp += len(pred - target)
        self.fn += len(target - pred)

    def reset(self):
        self.tp = 0
        self.fp = 0
        self.fn = 0

    def get(self):
        if self.tp == 0:
            return 0.0
        precision = self.tp / (self.tp + self.fp)
        recall = self.tp / (self.tp + self.fn)
        return 2 / (1 / precision + 1 / recall)

Зададим саму модель, ее шаги на обучении, валидации и инференсе, а также способ обучения

In [19]:
class ReceiptsModule(pl.LightningModule):
    def __init__(self, model_name, num_tags, learning_rate):
        super().__init__()
        self.model =  DistilBertModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.model.config.hidden_size, num_tags)
        self.num_tags = num_tags
        
        self.crf = CRF(num_tags, batch_first=True)

#         self.criterion = nn.CrossEntropyLoss(ignore_index=tag_to_index["PAD"])
        self.learning_rate = learning_rate
        
        self.f1_good_train = F1Score()
        self.f1_brand_train = F1Score()
        self.f1_good_val = F1Score()
        self.f1_brand_val = F1Score()

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.model(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)
        
        if labels is not None:
            loss = -self.crf(logits, labels, mask=attention_mask.type(torch.uint8), reduction='mean')
            return loss, logits
        else:
            return logits

    def training_step(self, batch, _):
        ids, tokens_sequence, attention_mask, goods, brands, targets = batch
        loss, logits = self(tokens_sequence, attention_mask, targets)
        
        tags_indices_sequence = self.crf.decode(logits)
        for i, tags_indices in enumerate(tags_indices_sequence):
            tags = [index_to_tag[index] for index in tags_indices[:len(tokens_sequence[i])]]
            entities = get_entities(tags)
            goods_pred = [' '.join(tokenizer.convert_ids_to_tokens(tokens_sequence[i][start:finish + 1])) for t, start, finish in entities if t == "GOOD"]
            brands_pred = [' '.join(tokenizer.convert_ids_to_tokens(tokens_sequence[i][start:finish + 1])) for t, start, finish in entities if t == "BRAND"]

            self.f1_good_train.update(goods_pred, goods[i])
            self.f1_brand_train.update(brands_pred, brands[i])
        self.log("loss/train", loss, on_epoch=True, batch_size=len(tags_indices_sequence))
        return loss

    def on_train_epoch_end(self):
        self.log("metric/f1_good_train", self.f1_good_train.get())
        self.log("metric/f1_brand_train", self.f1_brand_train.get())
        self.f1_good_train.reset()
        self.f1_brand_train.reset()

    def validation_step(self, batch, _):
        ids, tokens_sequence, attention_mask, goods, brands, targets = batch
        loss, logits = self(tokens_sequence, attention_mask, targets)
        
        tags_indices_sequence = self.crf.decode(logits)
        for i, tags_indices in enumerate(tags_indices_sequence):
            tags = [index_to_tag[index] for index in tags_indices[:len(tokens_sequence[i])]]
            entities = get_entities(tags)
            goods_pred = [' '.join(tokenizer.convert_ids_to_tokens(tokens_sequence[i][start:finish + 1])) for t, start, finish in entities if t == "GOOD"]
            brands_pred = [' '.join(tokenizer.convert_ids_to_tokens(tokens_sequence[i][start:finish + 1])) for t, start, finish in entities if t == "BRAND"]

            self.f1_good_train.update(goods_pred, goods[i])
            self.f1_brand_train.update(brands_pred, brands[i])
        self.log("loss/val", loss, on_epoch=True, batch_size=len(tags_indices_sequence))

    def on_validation_epoch_end(self):
        self.log("metric/f1_good_val", self.f1_good_val.get())
        self.log("metric/f1_brand_val", self.f1_brand_val.get())
        self.f1_good_val.reset()
        self.f1_brand_val.reset()

    def predict_step(self, batch, _):
        ids, tokens_sequence, attention_mask, goods, brands, targets = batch
        loss, logits = self(tokens_sequence, attention_mask, targets)
        
        tags_indices_sequence = self.crf.decode(logits)
        
        result = list()
        for i, tags_indices in enumerate(tags_indices_sequence):
            tags = [index_to_tag[index] for index in tags_indices[:len(tokens_sequence[i])]]
            entities = get_entities(tags)
            goods_pred = [' '.join(tokenizer.convert_ids_to_tokens(tokens_sequence[i][start:finish + 1])) for t, start, finish in entities if t == "GOOD"]
            brands_pred = [' '.join(tokenizer.convert_ids_to_tokens(tokens_sequence[i][start:finish + 1])) for t, start, finish in entities if t == "BRAND"]

            result.append([ids[i], goods_pred, brands_pred])
        return result

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), self.learning_rate)

In [20]:
LEARNING_RATE = 1e-4
model = ReceiptsModule(
    model_name=BERT_MODEL_NAME,
    num_tags=len(tag_to_index),
    learning_rate=LEARNING_RATE,
)

Some weights of the model checkpoint at distilbert-base-multilingual-cased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [21]:
trainer = pl.Trainer(
    accelerator="gpu",
    devices=[0],
    logger=pl.loggers.TensorBoardLogger("tb_logs", name="ner_bert_baseline"),
    max_epochs=70,
    log_every_n_steps=1
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Обучение модели

In [22]:
trainer.fit(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type            | Params
-----------------------------------------------
0 | model      | DistilBertModel | 134 M 
1 | classifier | Linear          | 4.6 K 
2 | crf        | CRF             | 48    
-----------------------------------------------
134 M     Trainable params
0         Non-trainable params
134 M     Total params
538.955   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

  score = torch.where(mask[i].unsqueeze(1), next_score, score)


Training: 0it [00:00, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Validation: 0it [00:00, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Validation: 0it [00:00, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Validation: 0it [00:00, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Validation: 0it [00:00, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Validation: 0it [00:00, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Validation: 0it [00:00, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Process Process-60:
Process Process-59:
Process Process-57:
Process Process-58:
Traceback (most recent call last):
  File "/home/worker/anaconda3/envs/py/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
Traceback (most recent call last):
  File "/home/worker/anaconda3/envs/py/lib/python3.9/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/home/worker/anaconda3/envs/py/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/worker/anaconda3/envs/py/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/worker/anaconda3/envs/py/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/worker/anaconda3/envs/py/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
  

	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  File "/home/worker/anaconda3/envs/py/lib/python3.9/threading.py", line 1080, in _wait_for_tstate_lock
    if lock.acquire(block, timeout):
  File "/home/worker/anaconda3/envs/py/lib/python3.9/threading.py", line 1080, in _wait_for_tstate_lock
    if lock.acquire(block, timeout):
KeyboardInterrupt
KeyboardInterrupt


Получение итоговых сущностей для тестового датасета

In [25]:
pred = trainer.predict(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 29it [00:00, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [30]:
for batch in dm.train_dataloader():
    break

In [33]:
len(batch[0])

512

In [37]:
ids, tokens_sequence, attention_mask, goods, brands, targets = batch
# loss, logits = self(tokens_sequence, attention_mask, targets)

# tags_indices_sequence = self.crf.decode(logits)

In [46]:
loss, logits = model(tokens_sequence, attention_mask, targets)

In [47]:
loss.item()

0.0015055611729621887

In [48]:
logits.shape

torch.Size([512, 57, 6])

In [49]:
tags_indices_sequence = model.crf.decode(logits)

In [51]:
tags_indices_sequence

[[0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  

In [None]:
        tags_indices_sequence = self.crf.decode(logits)
        for i, tags_indices in enumerate(tags_indices_sequence):
            tags = [index_to_tag[index] for index in tags_indices[:len(tokens_sequence[i])]]
            entities = get_entities(tags)
            goods_pred = [' '.join(tokenizer.convert_ids_to_tokens(tokens_sequence[i][start:finish + 1])) for t, start, finish in entities if t == "GOOD"]
            brands_pred = [' '.join(tokenizer.convert_ids_to_tokens(tokens_sequence[i][start:finish + 1])) for t, start, finish in entities if t == "BRAND"]


In [36]:
batch[1]

tensor([[   101,    532,  16481,  ...,      0,      0,      0],
        [   101,    525,  16859,  ...,      0,      0,      0],
        [   101,  11052, 108710,  ...,      0,      0,      0],
        ...,
        [   101,  70960,  11166,  ...,      0,      0,      0],
        [   101,    125,    119,  ...,      0,      0,      0],
        [   101,  12624,  10987,  ...,      0,      0,      0]])

In [28]:
pred

In [23]:
submission = pd.DataFrame(sum(pred, list()), columns=["id", "good", "brand"])
submission

Unnamed: 0,id,good,brand
0,0,клей,ермак
1,1,торт,
2,2,смеситель,
3,3,лимон,бар
4,4,коньяк,
...,...,...,...
4995,4995,рамка,
4996,4996,напиток,
4997,4997,наконечники,
4998,4998,шоколад,риттерспорт


In [24]:
submission.to_csv("submission_crf_lstm_preprocc.csv", index=False)

In [24]:
# del trainer
# del model
# del dm