## Getting data from the files and turning it into a Dataset

Inspired by https://github.com/dialogue-evaluation/RuNNE

In [2]:
import pandas as pd
from datasets import Dataset

def list2str(labels):
    result = []
    for i in labels:
        result.append(' '.join([str(j) for j in i]))
    return result

df = pd.read_json("public_dat/train.jsonl", lines=True)
df["ners"] = df.apply(lambda x: list2str(x['ners']), axis=1)

# type(df['ners'])
train_dataset = Dataset.from_pandas(df)
df

Unnamed: 0,ners,sentences,id
0,"[0 5 CITY, 16 23 PERSON, 34 41 PERSON, 46 62 L...",Бостон взорвали Тамерлан и Джохар Царнаевы из ...,0
1,"[21 28 PROFESSION, 53 67 ORGANIZATION, 100 148...",Умер избитый до комы гитарист и сооснователь г...,1
2,"[0 4 PERSON, 37 42 COUNTRY, 47 76 ORGANIZATION...",Путин подписал распоряжение о выходе России из...,2
3,"[0 11 PERSON, 36 47 PROFESSION, 49 60 PERSON, ...",Бенедикт XVI носил кардиостимулятор\nПапа Римс...,3
4,"[0 4 PERSON, 17 29 ORGANIZATION, 48 56 PROFESS...",Обама назначит в Верховный суд латиноамериканк...,4
...,...,...,...
514,"[42 46 COUNTRY, 82 87 COUNTRY, 104 123 LOCATIO...",Глава Малайзии: мы не хотим противостоять Кита...,514
515,"[1 4 PRODUCT, 31 33 FACILITY, 35 44 TIME, 48 6...",«Союз» впервые пристыковался к МКС за 6 часов\...,515
516,"[0 4 PERSON, 8 12 PERSON, 45 52 AGE, 72 80 PRO...",Трамп и Путин сделали совместное заявление к 7...,516
517,"[0 9 NATIONALITY, 58 72 PERSON, 101 115 PERSON...",Российский магнат устроил самую дорогую свадьб...,517


In [4]:
dev_df = pd.read_json("public_dat/dev.jsonl", lines=True)

dev_dataset = Dataset.from_pandas(dev_df)

dev_dataset

Dataset({
    features: ['senences', 'id'],
    num_rows: 65
})

## Model Architecture

In [None]:
import pytorch_lightning as pl
import torch
import json
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer

from transformers import (
    BertForTokenClassification,
    AdamW,
    get_linear_schedule_with_warmup,
    logging,
)
from tokenizers import BertWordPieceTokenizer
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

# from iobes_flat_dataset import IOBESFlatRuNNEDataset, collate_to_max_length
# from score import Evaluator


class BaselineRuBERT(pl.LightningModule):

    def __init__(
        self, in_path, out_path, tag_to_id, total_steps, lr=1e-4, weight_decay=0.02
    ):

        super().__init__()

        self.model = BertForTokenClassification.from_pretrained(
            "DeepPavlov/rubert-base-cased", num_labels=29 * 4 + 1, return_dict=False
        )

        self.lr = lr
        self.total_steps = total_steps
        self.weight_decay = weight_decay

        self.tag_to_id = tag_to_id

        tags = [None] * (max(self.tag_to_id.values()) + 1)
        for tag, idx in self.tag_to_id.items():
            tags[idx] = tag

        self.id_to_tag = tags

        self.in_path = in_path
        self.out_path = out_path

    def configure_optimizers(self):

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.weight_decay,
            },
            {
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]

        optimizer = AdamW(
            optimizer_grouped_parameters, betas=(0.9, 0.999), lr=self.lr, eps=1e-6
        )

        t_total = self.total_steps
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.lr,
            pct_start=0.3,
            total_steps=t_total,
            anneal_strategy="linear",
        )

        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):

        return self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels,
        )

    def training_step(self, batch, batch_idx):

        # batch:
        # [
        #     torch.LongTensor(tokens),
        #     torch.LongTensor(type_ids),
        #     torch.LongTensor(labels_ids),
        #     torch.LongTensor(offsets),
        #     torch.LongTensor([data["id"]]),
        #     context,
        #     filename,
        #     txtdata,
        #     tid,
        #     c_start,
        #     c_end
        # ]

        tokens, token_type_ids, labels = batch[0], batch[1], batch[2]

        attention_mask = (tokens != 0).long()
        loss = self(tokens, attention_mask, token_type_ids, labels=labels)[0]

        return {"loss": loss}

    def training_epoch_end(self, outputs):

        training_loss = sum([float(loss_dict["loss"]) for loss_dict in outputs])
        print("Loss on train: {:.6f}".format(training_loss))

        self.log("training_loss", training_loss)

    def validation_step(self, batch, batch_idx):

        (
            tokens,
            token_type_ids,
            labels,
            offsets,
            ids,
            contexts,
            filenames,
            txtdatas,
            tids,
            c_starts,
            c_ends,
        ) = batch

        attention_mask = (tokens != 0).long()
        loss, logits = self(tokens, attention_mask, token_type_ids, labels=labels)

        return {
            "loss": loss,
            "logits": logits,
            "labels": labels,
            "ids": ids,
            "offsets": offsets,
            "contexts": contexts,
            "filenames": filenames,
            "txtdatas": txtdatas,
            "tids": tids,
            "c_starts": c_starts,
            "c_ends": c_ends,
        }

    def validation_epoch_end(self, outputs):

        all_preds = []
        all_labels = []
        all_ids = []
        all_offsets = []
        all_contexts = []
        all_filenames = []
        all_txtdatas = []
        all_tids = []
        all_c_starts = []
        all_c_ends = []

        sum_loss = 0.0
        for output in outputs:

            loss = output["loss"]
            logits = output["logits"]
            labels = output["labels"]
            ids = output["ids"]
            offsets = output["offsets"]
            contexts = output["contexts"]
            filenames = output["filenames"]
            txtdatas = output["txtdatas"]
            tids = output["tids"]
            c_starts = output["c_starts"]
            c_ends = output["c_ends"]

            preds = torch.argmax(logits, dim=2)

            all_preds.extend(list(torch.split(preds, 1)))
            all_labels.extend(list(torch.split(labels, 1)))
            all_ids.extend(list(torch.split(ids, 1)))
            all_offsets.extend(list(torch.split(offsets, 1)))
            all_contexts.extend(contexts)
            all_filenames.extend(filenames)
            all_txtdatas.extend(txtdatas)
            all_tids.extend(tids)
            all_c_starts.extend(c_starts)
            all_c_ends.extend(c_ends)

            sum_loss += float(loss)

        print("\nLoss on dev: {:.6f}".format(sum_loss))

        self.log("validation_loss", sum_loss)

        sorted_zip = sorted(
            list(
                zip(
                    all_ids,
                    all_preds,
                    all_labels,
                    all_offsets,
                    all_contexts,
                    all_filenames,
                    all_txtdatas,
                    all_tids,
                    all_c_starts,
                    all_c_ends,
                )
            ),
            key=lambda x: x[0],
        )

        summary = self.compute_iobes_score(sorted_zip, mode="dev")

        self.log("mention_f1", summary["Mention F1"])
        self.log("mention_precision", summary["Mention precision"])
        self.log("mention_recall", summary["Mention recall"])
        self.log("macro_f1", summary["Macro F1"])
        self.log("macro_fewshot_f1", summary["Macro F1 few-shot"])

        return {}

## Training

In [None]:
VOCAB_PATH = "./vocab.txt"
NERS_PATH = "./eval/ref/ners.txt"
IN_PATH = "./eval"
OUT_PATH = "./eval"

TRAIN_PATH = "../data/train"
DEV_PATH = "../data/dev"
# TEST_PATH = "./data/test"

TRAIN_IDS_PATH = "../public_data/train.jsonl"
DEV_IDS_PATH = "../public_data/dev.jsonl"
# TEST_IDS_PATH = "./data/test.jsonl"

CKPT_PATH = "./checkpoints"
CKPT_FILE = "./checkpoints/epoch=206-step=37466.ckpt"

MAX_LEN = 128
BATCH_SIZE = 1
NUM_WORKERS = 8
MAX_EPOCHS = 1
LR = 1e-4
WEIGHT_DECAY = 0.02


logging.set_verbosity_error()

bertwptokenizer = BertWordPieceTokenizer(VOCAB_PATH, lowercase=False)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

dev_dataloader = DataLoader(
    dataset=dev_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

model = BaselineRuBERT(
    in_path=IN_PATH,
    out_path=OUT_PATH,
    tag_to_id=train_dataset.tag_to_id,
    total_steps=(len(train_dataset) // BATCH_SIZE) * MAX_EPOCHS,
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

checkpoint_callback = ModelCheckpoint(
    dirpath=CKPT_PATH,
    save_top_k=1,
    verbose=True,
    monitor="macro_f1",
    mode="max",
)

trainer = Trainer(
    callbacks=[checkpoint_callback],
    num_sanity_val_steps=-1,
    max_epochs=1,
)

trainer.fit(model, train_dataloader, dev_dataloader)