In [13]:
import gc
import pathlib

import pandas as pd
import torch.optim
from dataset import PunctuationRestorationDataset, collate_fn
from model import BertRestorePunctuation
from tokenization import PunctuationTokenizer
from torch import nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import AutoModel


def train_single_epoch(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    dataloader: DataLoader,
    device: torch.device = torch.device("cpu"),
    suffix: str = "",
    checkpoint_dir: "str | pathlib.Path" = "/external1/svpilipenko",
) -> None:
    """Run one optimisation pass over ``dataloader`` and checkpoint the weights.

    Every batch is unpacked as ``(input_ids, attention_mask, targets)``, moved to
    ``device`` and used for a single zero_grad / forward / backward / step.
    After the pass, ``model.state_dict()`` is saved to
    ``<checkpoint_dir>/model_<suffix>.pth``.

    Args:
        model: Called as ``model(inputs, attention_mask)``. Its output is permuted
            with ``permute(0, 2, 1)`` before the loss, so it must be rank 3,
            presumably ``(batch, seq_len, num_classes)`` — TODO confirm.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss applied to the permuted logits and the batch targets.
        dataloader: Yields ``(inputs, attention_mask, targets)`` tensors.
        device: Device the batch tensors are moved to (model must already be there).
        suffix: Appended to the checkpoint file name, e.g. the epoch number.
        checkpoint_dir: Directory for the checkpoint. Defaults to the original
            hard-coded location so existing calls behave the same; it is created
            if missing.
    """
    model.train()
    checkpoint_path = pathlib.Path(checkpoint_dir) / f"model_{suffix}.pth"
    # Fail fast (before a full epoch of work) if the checkpoint location is unusable.
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    # ``with`` closes the progress bar even when the epoch is interrupted
    # (e.g. KeyboardInterrupt), instead of leaving a dangling widget.
    with tqdm(dataloader) as pbar:
        for x, attention_mask, y in pbar:
            optimizer.zero_grad()
            inputs, attention_mask, targets = (
                x.to(device),
                attention_mask.to(device),
                y.to(device),
            )
            outputs = model(inputs, attention_mask)
            # CrossEntropyLoss expects the class dimension second: (N, C, ...).
            loss = criterion(outputs.permute(0, 2, 1), targets)
            loss.backward()
            optimizer.step()
            pbar.set_description(f"Batch Loss: {loss.item():.4f}")

    torch.save(model.state_dict(), checkpoint_path)


@torch.no_grad()
def predict_single_text(
    model: nn.Module,
    text: str,
    device: torch.device = torch.device("cpu"),
    tokenizer=None,
) -> str:
    """Restore punctuation in a single ``text`` using ``model``.

    The text is encoded with ``tokenizer.encode``. A batch dimension is added and
    the model is run in eval mode. The per-token argmax class ids are passed,
    together with the original inputs, to ``tokenizer.decode``, whose string
    result is returned.

    Args:
        model: Called as ``model(inputs, attention_mask)`` on batched tensors;
            argmax is taken over its last dimension.
        text: Raw input text.
        device: Device to run the model on.
        tokenizer: Object providing ``encode(text) -> (inputs, mask, _)`` and
            ``decode(inputs, mask, predictions) -> str``. When omitted, the
            notebook-level global ``tokenizer`` is used (the original behaviour).

    Returns:
        The decoded string produced by ``tokenizer.decode``.

    Raises:
        NameError: if no tokenizer is passed and no global ``tokenizer`` exists.
    """
    if tokenizer is None:
        # Backward compatibility: the original relied on the global ``tokenizer``.
        try:
            tokenizer = globals()["tokenizer"]
        except KeyError:
            raise NameError(
                "predict_single_text needs a tokenizer: pass tokenizer=... "
                "or define a global `tokenizer`"
            ) from None

    # Remember the caller's mode so prediction does not leave the model stuck in eval.
    was_training = model.training
    model.eval()
    try:
        inputs, attention_mask, _ = tokenizer.encode(text)
        logits = model(
            inputs.unsqueeze(0).to(device),
            attention_mask.unsqueeze(0).to(device),
        )
        # Bring predictions back to CPU so they live on the same device as ``inputs``.
        predictions = torch.argmax(logits, dim=-1).squeeze(0).cpu()
    finally:
        model.train(was_training)
    return tokenizer.decode(inputs, attention_mask, predictions)


# Article bodies from the Lenta.ru news dump: rows with no text are discarded and
# the index is renumbered 0..N-1.
corpus = (
    pd.read_csv(pathlib.Path("/external1/svpilipenko", "lenta-ru-news.csv"))["text"]
    .dropna()
    .reset_index(drop=True)
)

# Pretrained backbone plus the matching punctuation tokenizer.
model_name = "cointegrated/rubert-tiny"
bert = AutoModel.from_pretrained(model_name, output_hidden_states=True)
# 510 presumably leaves room for special tokens within BERT's 512 limit — TODO confirm.
tokenizer = PunctuationTokenizer(model_name, truncation_threshold=510)
# Vocabulary previously saved next to the notebook.
cache_file_path = pathlib.Path(".").joinpath("tokenizer_vocab.json")
tokenizer.load_vocab(cache_file_path)

dataset = PunctuationRestorationDataset(corpus, tokenizer)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

Some weights of the model checkpoint at cointegrated/rubert-tiny were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
SEED = 42
CUDA_DEVICE = "cuda:6"  # GPU index hard-coded for the original machine; adjust per host
LEARNING_RATE = 3e-4

# Seed torch's default generator (CPU and all CUDA devices) before the model is
# built. This makes reproducible any layers BertRestorePunctuation initialises
# randomly, dropout, and the DataLoader's shuffle order. The sampler draws its
# seed from the default generator each time iteration starts.
torch.manual_seed(SEED)

device = (
    torch.device(CUDA_DEVICE) if torch.cuda.is_available() else torch.device("cpu")
)
model = BertRestorePunctuation(
    bert, num_classes=len(tokenizer.punctuation_vocab)
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()
torch.backends.cudnn.benchmark = True

In [15]:
NUM_EPOCHS = 4

for epoch in range(1, NUM_EPOCHS + 1):
    train_single_epoch(model, optimizer, criterion, dataloader, device, suffix=str(epoch))
    print(f"Epoch #{epoch} passed...")

# Release GPU memory once training has finished (each epoch is checkpointed to disk).
# The optimizer holds references to every model parameter and to Adam's state
# buffers. It must therefore be deleted as well, otherwise `del model` frees
# nothing and `empty_cache()` has nothing to hand back. This cleanup intentionally
# does not run if training is interrupted, so the model stays available for
# inspection.
del model, bert, optimizer
gc.collect()
torch.cuda.empty_cache()

HBox(children=(FloatProgress(value=0.0, max=25031.0), HTML(value='')))

KeyboardInterrupt: 

'Бои уСопоцкинаи Друскеникзакончились отступлением германцев Неприятельприблизившись с севера кОсовцу начал артиллерий скую борьбускрепость юВ артиллерийском бою принимают участие тяжелые кали бры Сраннего утра14 сентября огонь до стигзначительно го на пря жения По пытка германской пехоты пробить ся бли жек крепости отражена В Галиции мы заняли Дем бицу Большая колонна отступавшая пошоссе отПеремышлякСанок у обстреливалась с высотнашей батареей и бежала бросив парки обози автомобили Выла з'