In [1]:
import warnings
warnings.filterwarnings("ignore")
from tqdm.notebook import tqdm
import time

import matplotlib.pyplot as plt

from datasets import load_dataset
from datasets import Dataset, DatasetDict

from transformers import AutoTokenizer, AutoModelForMaskedLM

import pandas as pd
import numpy as np
import re, os

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# from huggingface_hub import notebook_login
# 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
# notebook_login()

pd.set_option('display.max_colwidth', None)

In [2]:
def seed_everything(seed=42):
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Currently using "{device.upper()}" device.')

Currently using "CUDA" device.


In [3]:
tokenizer = AutoTokenizer.from_pretrained("sberbank-ai/ruRoberta-large")

In [4]:
with open("qa_arith.txt", encoding="utf-8") as f:
    train = f.read()

with open("qa_arith_test.txt", encoding="utf-8") as f:
    test = f.read()

train = train.split("\n\n")
test = test.split("\n\n")

In [5]:
train = pd.DataFrame({"text": train})
train["first"] = train["text"].apply(lambda x: x.split("A: ")[0].replace("\n", "").replace("Q: ", ""))
train["second"] = train["text"].apply(lambda x: x.split("A: ")[-1])

train = train[train["second"].str.isdigit()]
train["second"] = train["second"].astype(np.float32)

test = pd.DataFrame({"text": test})
test["first"] = test["text"].apply(lambda x: x.split("A: ")[0].replace("\n", "").replace("Q: ", ""))
test["second"] = test["text"].apply(lambda x: x.split("A: ")[-1])

test = test[test["second"].str.isdigit()]
test["second"] = test["second"].astype(np.float32)

train.head(2)

Unnamed: 0,text,first,second
0,"\nQ: Утром литовец покормил котика 2 раза, а после обеда еще 16 раз. Сколько всего раз он покормил котика?\nA: 18","Утром литовец покормил котика 2 раза, а после обеда еще 16 раз. Сколько всего раз он покормил котика?",18.0
1,"Q: В понедельник кузнец сфотографировал 19 пирожков, а во вторник еще 16. Сколько всего пирожков он сфотографировал за эти два дня?\nA: 35","В понедельник кузнец сфотографировал 19 пирожков, а во вторник еще 16. Сколько всего пирожков он сфотографировал за эти два дня?",35.0


In [6]:
def tokenize_for_roberta(text):
    input_ids = []
    attention_masks = []
    
    for sent in tqdm(text):
        encoded_sent = tokenizer(sent, padding='max_length', truncation=True, max_length=64)
        input_ids.append(encoded_sent.get('input_ids'))
        attention_masks.append(encoded_sent.get('attention_mask'))

    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)

    return input_ids, attention_masks

In [7]:
train_ids, train_masks = tokenize_for_roberta(train["first"].values)
train_labels = torch.FloatTensor(train["second"].values)

test_ids, test_masks = tokenize_for_roberta(test["first"].values)
test_labels = torch.FloatTensor(test["second"].values)

train_ds = TensorDataset(train_ids, train_masks, train_labels)
test_ds = TensorDataset(test_ids, test_masks, test_labels)

train_dataloader = DataLoader(train_ds, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=8, shuffle=False)

  0%|          | 0/73939 [00:00<?, ?it/s]

  0%|          | 0/18447 [00:00<?, ?it/s]

In [8]:
class MeanPooling(nn.Module):
    
    def forward(self, last_hidden_state, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, dim=1)
        sum_mask = input_mask_expanded.sum(dim=1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings

In [9]:
class RobertaModel(nn.Module):
    def __init__(self, dropout=0.1, freeze=False, **kwargs):
        super(RobertaModel, self).__init__(**kwargs)
        
        self.extractor = AutoModelForMaskedLM.from_pretrained("sberbank-ai/ruRoberta-large").to(device)
        if freeze:
            for param in self.extractor.parameters():
                param.requires_grad = False
                
        self.drop = nn.Dropout(p=dropout)
        self.pooler = MeanPooling()
        self.fc = nn.Linear(1024, 1)
        
        """
        self.fc = nn.Sequential(nn.Linear(1024, 512),
                                nn.ReLU(),
                                nn.Dropout(p=dropout),
                                nn.Linear(512, 1))
        """
        
    def forward(self, ids, mask):        
        out = self.extractor(input_ids=ids, attention_mask=mask, output_hidden_states=True) 
        out = self.pooler(out.hidden_states[-1], mask)  # [8, 64, 1024]
        out = self.drop(out)
        outputs = self.fc(out)
        return outputs

In [10]:
@torch.no_grad()
def validate_one_batch(data, model, criterion):
    model.eval()
    data = [d.to(device) for d in data]
    texts, masks, labels = data
    
    out = model(texts, masks)
    loss = criterion(out, labels)
    
    return loss.item()

In [13]:
class EarlyStopping:
    def __init__(self, patience=2, min_delta=0, path='model.pth'):
        self.path = path
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model=None, **kwargs):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss > self.min_delta:
            checkpoint = {'model': model, }
            torch.save(checkpoint, self.path)
            print(f'Model saved to: {self.path}')
            self.best_loss = val_loss
            self.counter = 0
        elif self.best_loss - val_loss < self.min_delta:
            self.counter += 1
            print(f"INFO: Early stopping counter {self.counter} of {self.patience}")
            if self.counter >= self.patience:
                print('INFO: Early stopping')
                self.early_stop = True

In [16]:
roberta = RobertaModel(dropout=0.2, freeze=True).to(device)

criterion = nn.MSELoss()

optimizer = torch.optim.AdamW(roberta.fc.parameters(), lr=0.0001, weight_decay=0.01)  # freeze roberta parameters
# optimizer = torch.optim.AdamW(roberta.parameters(), lr=2e-5, weight_decay=0.01)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, min_lr=1e-6, factor=0.1)
stopper = EarlyStopping(patience=5)

# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=125, eta_min=1e-8)

In [14]:
def run(model, epochs=3, print_freq=500, num_accumulation_steps=2):
    # best_loss = np.inf
    
    for epoch in range(1, epochs+1):

        train_loss = []
        for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            time_1 = time.time()
            
            model.train()
            batch = [d.to(device) for d in batch]
            texts, masks, labels = batch
            
            output = model(texts, masks)
            loss = criterion(output, labels)
            train_loss.append(loss.item())
            
            loss = loss / num_accumulation_steps
            loss.backward()
            
            if ((step + 1) % num_accumulation_steps == 0) or (step + 1 == len(train_dataloader)):
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
                # scheduler.step()  # cosine

            if (step+1) % print_freq == 0:
                print('epoch:', epoch, 
                      '\tstep:', step+1, '/', len(train_dataloader),
                      '\ttrain loss:', '{:.4f}'.format(loss),
                      '\ttime:', '{:.4f}'.format((time.time()-time_1)*print_freq), 's')

        valid_loss = []
        for step, batch in tqdm(enumerate(tqdm(test_dataloader)), total=len(test_dataloader)):
            loss = validate_one_batch(batch, model, criterion)
            valid_loss.append(loss)
        
        stopper(np.mean(valid_loss), model)
        scheduler.step(np.mean(valid_loss))  # resuce on plateau
        print('epoch:', epoch, '/', epochs,
              '\ttrain loss:', '{:.4f}'.format(np.mean(train_loss)),
              '\tvalid loss:', '{:.4f}'.format(np.mean(valid_loss)),
             )
        # if np.mean(valid_loss) < best_loss:
        #     best_loss = np.mean(valid_loss)
        #     torch.save({'model': model}, "calc_model.pth")

In [17]:
run(roberta, epochs=16)
roberta = torch.load("model.pth", map_location=device)["model"]

  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 1 	step: 500 / 9243 	train loss: 780228.7500 	time: 36.5059 s
epoch: 1 	step: 1000 / 9243 	train loss: 1331226.5000 	time: 35.2907 s
epoch: 1 	step: 1500 / 9243 	train loss: 174514.0938 	time: 34.8475 s
epoch: 1 	step: 2000 / 9243 	train loss: 1263771.7500 	time: 34.9973 s
epoch: 1 	step: 2500 / 9243 	train loss: 22744.0371 	time: 34.9996 s
epoch: 1 	step: 3000 / 9243 	train loss: 3815086.5000 	time: 37.5005 s
epoch: 1 	step: 3500 / 9243 	train loss: 93935.8281 	time: 38.2247 s
epoch: 1 	step: 4000 / 9243 	train loss: 1246.4340 	time: 37.7791 s
epoch: 1 	step: 4500 / 9243 	train loss: 486993.4688 	time: 35.4112 s
epoch: 1 	step: 5000 / 9243 	train loss: 4572265.5000 	time: 36.3410 s
epoch: 1 	step: 5500 / 9243 	train loss: 397546.1562 	time: 34.7111 s
epoch: 1 	step: 6000 / 9243 	train loss: 3125989.0000 	time: 37.9559 s
epoch: 1 	step: 6500 / 9243 	train loss: 467878.0000 	time: 34.6428 s
epoch: 1 	step: 7000 / 9243 	train loss: 793099.5000 	time: 39.4552 s
epoch: 1 	step: 7500

  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

epoch: 1 / 16 	train loss: 3618106.7284 	valid loss: 3467607.5241


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 2 	step: 500 / 9243 	train loss: 4505588.0000 	time: 35.0037 s
epoch: 2 	step: 1000 / 9243 	train loss: 161541.2812 	time: 40.0083 s
epoch: 2 	step: 1500 / 9243 	train loss: 4742413.0000 	time: 35.0146 s
epoch: 2 	step: 2000 / 9243 	train loss: 429542.7500 	time: 35.0049 s
epoch: 2 	step: 2500 / 9243 	train loss: 751019.8125 	time: 54.9971 s
epoch: 2 	step: 3000 / 9243 	train loss: 1792000.2500 	time: 34.9991 s
epoch: 2 	step: 3500 / 9243 	train loss: 5625169.0000 	time: 34.9939 s
epoch: 2 	step: 4000 / 9243 	train loss: 2651382.5000 	time: 39.9998 s
epoch: 2 	step: 4500 / 9243 	train loss: 2837918.0000 	time: 34.9956 s
epoch: 2 	step: 5000 / 9243 	train loss: 1368206.6250 	time: 35.0001 s
epoch: 2 	step: 5500 / 9243 	train loss: 6137807.5000 	time: 35.0389 s
epoch: 2 	step: 6000 / 9243 	train loss: 15380.4316 	time: 39.9997 s
epoch: 2 	step: 6500 / 9243 	train loss: 61704.9883 	time: 35.0012 s
epoch: 2 	step: 7000 / 9243 	train loss: 121816.7344 	time: 40.0128 s
epoch: 2 	step:

  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 2 / 16 	train loss: 3421550.5885 	valid loss: 3292732.1941


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 3 	step: 500 / 9243 	train loss: 1013403.7500 	time: 34.9938 s
epoch: 3 	step: 1000 / 9243 	train loss: 5286512.0000 	time: 39.9957 s
epoch: 3 	step: 1500 / 9243 	train loss: 71608.3125 	time: 39.9677 s
epoch: 3 	step: 2000 / 9243 	train loss: 2384494.5000 	time: 40.0006 s
epoch: 3 	step: 2500 / 9243 	train loss: 3543054.0000 	time: 34.9985 s
epoch: 3 	step: 3000 / 9243 	train loss: 687971.7500 	time: 34.9977 s
epoch: 3 	step: 3500 / 9243 	train loss: 27057.7500 	time: 34.9981 s
epoch: 3 	step: 4000 / 9243 	train loss: 5647901.0000 	time: 39.9960 s
epoch: 3 	step: 4500 / 9243 	train loss: 268778.2500 	time: 40.0112 s
epoch: 3 	step: 5000 / 9243 	train loss: 4932272.0000 	time: 35.0363 s
epoch: 3 	step: 5500 / 9243 	train loss: 35385.0469 	time: 39.9694 s
epoch: 3 	step: 6000 / 9243 	train loss: 309069.1875 	time: 35.0213 s
epoch: 3 	step: 6500 / 9243 	train loss: 34902.3242 	time: 35.0109 s
epoch: 3 	step: 7000 / 9243 	train loss: 52300.3242 	time: 34.9927 s
epoch: 3 	step: 7500

  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 3 / 16 	train loss: 3275820.2183 	valid loss: 3166533.2200


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 4 	step: 500 / 9243 	train loss: 2093843.2500 	time: 34.9917 s
epoch: 4 	step: 1000 / 9243 	train loss: 119186.3438 	time: 34.9903 s
epoch: 4 	step: 1500 / 9243 	train loss: 823428.0625 	time: 35.0101 s
epoch: 4 	step: 2000 / 9243 	train loss: 48477.6289 	time: 39.9663 s
epoch: 4 	step: 2500 / 9243 	train loss: 4481759.0000 	time: 39.9945 s
epoch: 4 	step: 3000 / 9243 	train loss: 95327.5312 	time: 35.0001 s
epoch: 4 	step: 3500 / 9243 	train loss: 286808.6875 	time: 39.9998 s
epoch: 4 	step: 4000 / 9243 	train loss: 4154322.5000 	time: 40.0010 s
epoch: 4 	step: 4500 / 9243 	train loss: 178266.6875 	time: 40.0056 s
epoch: 4 	step: 5000 / 9243 	train loss: 3376702.0000 	time: 34.9699 s
epoch: 4 	step: 5500 / 9243 	train loss: 3065041.0000 	time: 40.0025 s
epoch: 4 	step: 6000 / 9243 	train loss: 2532404.5000 	time: 34.9988 s
epoch: 4 	step: 6500 / 9243 	train loss: 104356.3125 	time: 35.0047 s
epoch: 4 	step: 7000 / 9243 	train loss: 2892514.0000 	time: 39.9939 s
epoch: 4 	step: 

  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 4 / 16 	train loss: 3171720.4655 	valid loss: 3081812.6045


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 5 	step: 500 / 9243 	train loss: 644402.1250 	time: 40.0001 s
epoch: 5 	step: 1000 / 9243 	train loss: 132078.4219 	time: 39.9969 s
epoch: 5 	step: 1500 / 9243 	train loss: 122592.1875 	time: 35.0326 s
epoch: 5 	step: 2000 / 9243 	train loss: 365579.4375 	time: 34.9960 s
epoch: 5 	step: 2500 / 9243 	train loss: 587089.5000 	time: 34.9997 s
epoch: 5 	step: 3000 / 9243 	train loss: 5386731.0000 	time: 34.9960 s
epoch: 5 	step: 3500 / 9243 	train loss: 1259062.2500 	time: 40.0009 s
epoch: 5 	step: 4000 / 9243 	train loss: 76416.7500 	time: 39.9994 s
epoch: 5 	step: 4500 / 9243 	train loss: 850704.3750 	time: 35.0322 s
epoch: 5 	step: 5000 / 9243 	train loss: 139926.5938 	time: 39.9638 s
epoch: 5 	step: 5500 / 9243 	train loss: 2433520.0000 	time: 40.0099 s
epoch: 5 	step: 6000 / 9243 	train loss: 2014755.0000 	time: 34.9950 s
epoch: 5 	step: 6500 / 9243 	train loss: 3806402.7500 	time: 35.0107 s
epoch: 5 	step: 7000 / 9243 	train loss: 1568809.0000 	time: 35.0051 s
epoch: 5 	step: 

  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 5 / 16 	train loss: 3101924.1399 	valid loss: 3025792.7062


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 6 	step: 500 / 9243 	train loss: 2321145.2500 	time: 35.0395 s
epoch: 6 	step: 1000 / 9243 	train loss: 2576202.0000 	time: 40.0003 s
epoch: 6 	step: 1500 / 9243 	train loss: 509227.2500 	time: 40.0021 s
epoch: 6 	step: 2000 / 9243 	train loss: 101854.3125 	time: 35.0406 s
epoch: 6 	step: 2500 / 9243 	train loss: 111196.7031 	time: 35.0006 s
epoch: 6 	step: 3000 / 9243 	train loss: 243802.9375 	time: 40.0075 s
epoch: 6 	step: 3500 / 9243 	train loss: 2916701.5000 	time: 35.0038 s
epoch: 6 	step: 4000 / 9243 	train loss: 1172777.2500 	time: 34.9698 s
epoch: 6 	step: 4500 / 9243 	train loss: 1080270.6250 	time: 35.0348 s
epoch: 6 	step: 5000 / 9243 	train loss: 559203.6250 	time: 39.2280 s
epoch: 6 	step: 5500 / 9243 	train loss: 106093.1875 	time: 38.2019 s
epoch: 6 	step: 6000 / 9243 	train loss: 107483.6953 	time: 34.9950 s
epoch: 6 	step: 6500 / 9243 	train loss: 1585364.7500 	time: 37.5001 s
epoch: 6 	step: 7000 / 9243 	train loss: 131620.8594 	time: 38.5137 s
epoch: 6 	step:

  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 6 / 16 	train loss: 3057201.9537 	valid loss: 2989921.9505


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 7 	step: 500 / 9243 	train loss: 1716174.5000 	time: 40.6249 s
epoch: 7 	step: 1000 / 9243 	train loss: 489339.5000 	time: 39.4164 s
epoch: 7 	step: 1500 / 9243 	train loss: 148279.5625 	time: 37.6941 s
epoch: 7 	step: 2000 / 9243 	train loss: 2833139.0000 	time: 40.2324 s
epoch: 7 	step: 2500 / 9243 	train loss: 3236418.0000 	time: 38.3807 s
epoch: 7 	step: 3000 / 9243 	train loss: 3668527.5000 	time: 34.9966 s
epoch: 7 	step: 3500 / 9243 	train loss: 1207281.7500 	time: 39.9978 s
epoch: 7 	step: 4000 / 9243 	train loss: 156152.9062 	time: 39.1008 s
epoch: 7 	step: 4500 / 9243 	train loss: 1474822.2500 	time: 37.9734 s
epoch: 7 	step: 5000 / 9243 	train loss: 223808.5625 	time: 39.9969 s
epoch: 7 	step: 5500 / 9243 	train loss: 174012.0312 	time: 40.8949 s
epoch: 7 	step: 6000 / 9243 	train loss: 260281.5625 	time: 38.2804 s
epoch: 7 	step: 6500 / 9243 	train loss: 270622.1875 	time: 40.0389 s
epoch: 7 	step: 7000 / 9243 	train loss: 5688628.0000 	time: 40.0039 s
epoch: 7 	step

  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 7 / 16 	train loss: 3025711.4275 	valid loss: 2965805.3908


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 8 	step: 500 / 9243 	train loss: 305811.9062 	time: 38.8560 s
epoch: 8 	step: 1000 / 9243 	train loss: 155576.7812 	time: 40.2997 s
epoch: 8 	step: 1500 / 9243 	train loss: 1989103.3750 	time: 42.0010 s
epoch: 8 	step: 2000 / 9243 	train loss: 2715311.5000 	time: 40.3157 s
epoch: 8 	step: 2500 / 9243 	train loss: 2883716.5000 	time: 38.9982 s
epoch: 8 	step: 3000 / 9243 	train loss: 134073.0000 	time: 37.3336 s
epoch: 8 	step: 3500 / 9243 	train loss: 668639.3750 	time: 39.1003 s
epoch: 8 	step: 4000 / 9243 	train loss: 2532245.0000 	time: 39.6705 s
epoch: 8 	step: 4500 / 9243 	train loss: 132834.3438 	time: 35.0333 s
epoch: 8 	step: 5000 / 9243 	train loss: 867735.0000 	time: 39.9939 s
epoch: 8 	step: 5500 / 9243 	train loss: 598683.6875 	time: 39.9961 s
epoch: 8 	step: 6000 / 9243 	train loss: 1135590.3750 	time: 35.0002 s
epoch: 8 	step: 6500 / 9243 	train loss: 607780.7500 	time: 39.0003 s
epoch: 8 	step: 7000 / 9243 	train loss: 187339.0469 	time: 38.9998 s
epoch: 8 	step: 

  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 8 / 16 	train loss: 3005640.1590 	valid loss: 2951299.8879


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 9 	step: 500 / 9243 	train loss: 3144862.0000 	time: 35.0031 s
epoch: 9 	step: 1000 / 9243 	train loss: 1077929.0000 	time: 40.0009 s
epoch: 9 	step: 1500 / 9243 	train loss: 788731.2500 	time: 39.9886 s
epoch: 9 	step: 2000 / 9243 	train loss: 245668.8438 	time: 35.0012 s
epoch: 9 	step: 2500 / 9243 	train loss: 443526.5312 	time: 35.0014 s
epoch: 9 	step: 3000 / 9243 	train loss: 1545951.2500 	time: 39.9661 s
epoch: 9 	step: 3500 / 9243 	train loss: 868973.0000 	time: 40.0076 s
epoch: 9 	step: 4000 / 9243 	train loss: 1443458.3750 	time: 34.9817 s
epoch: 9 	step: 4500 / 9243 	train loss: 160362.0938 	time: 40.0012 s
epoch: 9 	step: 5000 / 9243 	train loss: 1408904.5000 	time: 39.9971 s
epoch: 9 	step: 5500 / 9243 	train loss: 781957.3125 	time: 39.9578 s
epoch: 9 	step: 6000 / 9243 	train loss: 157769.8125 	time: 39.9971 s
epoch: 9 	step: 6500 / 9243 	train loss: 844908.2500 	time: 35.0024 s
epoch: 9 	step: 7000 / 9243 	train loss: 168719.9531 	time: 40.5570 s
epoch: 9 	step: 

  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 9 / 16 	train loss: 2992855.6979 	valid loss: 2941433.5604


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 10 	step: 500 / 9243 	train loss: 683250.9375 	time: 40.0039 s
epoch: 10 	step: 1000 / 9243 	train loss: 171327.4062 	time: 39.9870 s
epoch: 10 	step: 1500 / 9243 	train loss: 759520.0000 	time: 39.4994 s
epoch: 10 	step: 2000 / 9243 	train loss: 1232053.3750 	time: 39.5066 s
epoch: 10 	step: 2500 / 9243 	train loss: 4398243.5000 	time: 40.0031 s
epoch: 10 	step: 3000 / 9243 	train loss: 322359.4688 	time: 39.9914 s
epoch: 10 	step: 3500 / 9243 	train loss: 177734.1094 	time: 39.9982 s
epoch: 10 	step: 4000 / 9243 	train loss: 288423.6250 	time: 39.9939 s
epoch: 10 	step: 4500 / 9243 	train loss: 4400385.0000 	time: 39.9911 s
epoch: 10 	step: 5000 / 9243 	train loss: 176182.0625 	time: 34.9996 s
epoch: 10 	step: 5500 / 9243 	train loss: 2445030.7500 	time: 39.9952 s
epoch: 10 	step: 6000 / 9243 	train loss: 1086411.1250 	time: 39.9622 s
epoch: 10 	step: 6500 / 9243 	train loss: 1427954.1250 	time: 35.0047 s
epoch: 10 	step: 7000 / 9243 	train loss: 1978200.6250 	time: 40.0010 s


  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 10 / 16 	train loss: 2982519.5029 	valid loss: 2935005.9168


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 11 	step: 500 / 9243 	train loss: 3687509.7500 	time: 39.9604 s
epoch: 11 	step: 1000 / 9243 	train loss: 238210.2812 	time: 40.0003 s
epoch: 11 	step: 1500 / 9243 	train loss: 238650.7812 	time: 35.0024 s
epoch: 11 	step: 2000 / 9243 	train loss: 1577852.6250 	time: 39.9948 s
epoch: 11 	step: 2500 / 9243 	train loss: 1706239.5000 	time: 40.0006 s
epoch: 11 	step: 3000 / 9243 	train loss: 907468.3750 	time: 40.0027 s
epoch: 11 	step: 3500 / 9243 	train loss: 210644.0000 	time: 39.9992 s
epoch: 11 	step: 4000 / 9243 	train loss: 3002018.0000 	time: 34.9956 s
epoch: 11 	step: 4500 / 9243 	train loss: 1012213.2500 	time: 40.0044 s
epoch: 11 	step: 5000 / 9243 	train loss: 3063286.0000 	time: 40.0000 s
epoch: 11 	step: 5500 / 9243 	train loss: 247928.1250 	time: 40.0040 s
epoch: 11 	step: 6000 / 9243 	train loss: 171699.6875 	time: 34.9981 s
epoch: 11 	step: 6500 / 9243 	train loss: 206602.2344 	time: 34.9991 s
epoch: 11 	step: 7000 / 9243 	train loss: 790120.0000 	time: 39.9998 s
e

  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 11 / 16 	train loss: 2976726.0992 	valid loss: 2930549.9782


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 12 	step: 500 / 9243 	train loss: 444488.0000 	time: 39.9953 s
epoch: 12 	step: 1000 / 9243 	train loss: 2173069.5000 	time: 39.9982 s
epoch: 12 	step: 1500 / 9243 	train loss: 1914354.0000 	time: 39.9983 s
epoch: 12 	step: 2000 / 9243 	train loss: 232283.3125 	time: 40.0004 s
epoch: 12 	step: 2500 / 9243 	train loss: 737745.6250 	time: 39.9964 s
epoch: 12 	step: 3000 / 9243 	train loss: 261604.2500 	time: 40.0000 s
epoch: 12 	step: 3500 / 9243 	train loss: 997674.6875 	time: 39.9983 s
epoch: 12 	step: 4000 / 9243 	train loss: 248974.0000 	time: 34.9866 s
epoch: 12 	step: 4500 / 9243 	train loss: 205303.3281 	time: 34.9908 s
epoch: 12 	step: 5000 / 9243 	train loss: 2484511.5000 	time: 34.9948 s
epoch: 12 	step: 5500 / 9243 	train loss: 3796610.5000 	time: 35.0033 s
epoch: 12 	step: 6000 / 9243 	train loss: 2627513.5000 	time: 35.0289 s
epoch: 12 	step: 6500 / 9243 	train loss: 2468960.5000 	time: 40.0002 s
epoch: 12 	step: 7000 / 9243 	train loss: 2219771.7500 	time: 39.9662 s


  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 12 / 16 	train loss: 2972499.5591 	valid loss: 2927514.2489


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 13 	step: 500 / 9243 	train loss: 693046.7500 	time: 35.0051 s
epoch: 13 	step: 1000 / 9243 	train loss: 212194.2188 	time: 40.0000 s
epoch: 13 	step: 1500 / 9243 	train loss: 1788975.1250 	time: 35.0397 s
epoch: 13 	step: 2000 / 9243 	train loss: 1544893.0000 	time: 35.0057 s
epoch: 13 	step: 2500 / 9243 	train loss: 4151353.5000 	time: 35.0004 s
epoch: 13 	step: 3000 / 9243 	train loss: 1353272.7500 	time: 40.0000 s
epoch: 13 	step: 3500 / 9243 	train loss: 206201.4219 	time: 35.0049 s
epoch: 13 	step: 4000 / 9243 	train loss: 1792006.7500 	time: 40.0002 s
epoch: 13 	step: 4500 / 9243 	train loss: 2769057.2500 	time: 35.0333 s
epoch: 13 	step: 5000 / 9243 	train loss: 708582.1250 	time: 39.9951 s
epoch: 13 	step: 5500 / 9243 	train loss: 3052479.7500 	time: 34.9975 s
epoch: 13 	step: 6000 / 9243 	train loss: 384801.5625 	time: 35.0002 s
epoch: 13 	step: 6500 / 9243 	train loss: 186538.4375 	time: 35.0358 s
epoch: 13 	step: 7000 / 9243 	train loss: 3213196.5000 	time: 34.9929 s

  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 13 / 16 	train loss: 2969427.7207 	valid loss: 2925656.1351


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 14 	step: 500 / 9243 	train loss: 2170388.7500 	time: 35.0003 s
epoch: 14 	step: 1000 / 9243 	train loss: 742388.4375 	time: 35.0089 s
epoch: 14 	step: 1500 / 9243 	train loss: 372378.3125 	time: 35.0003 s
epoch: 14 	step: 2000 / 9243 	train loss: 914614.8125 	time: 39.9944 s
epoch: 14 	step: 2500 / 9243 	train loss: 1702918.0000 	time: 35.0316 s
epoch: 14 	step: 3000 / 9243 	train loss: 3426972.0000 	time: 39.9971 s
epoch: 14 	step: 3500 / 9243 	train loss: 5172105.0000 	time: 39.9964 s
epoch: 14 	step: 4000 / 9243 	train loss: 277331.9375 	time: 35.0000 s
epoch: 14 	step: 4500 / 9243 	train loss: 176611.7500 	time: 34.9962 s
epoch: 14 	step: 5000 / 9243 	train loss: 250364.2031 	time: 39.9987 s
epoch: 14 	step: 5500 / 9243 	train loss: 4152430.5000 	time: 39.9946 s
epoch: 14 	step: 6000 / 9243 	train loss: 3494129.0000 	time: 34.9969 s
epoch: 14 	step: 6500 / 9243 	train loss: 1131069.0000 	time: 35.0124 s
epoch: 14 	step: 7000 / 9243 	train loss: 262815.3750 	time: 35.0044 s


  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 14 / 16 	train loss: 2967484.8620 	valid loss: 2923975.6493


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 15 	step: 500 / 9243 	train loss: 4354184.0000 	time: 40.0001 s
epoch: 15 	step: 1000 / 9243 	train loss: 2774218.0000 	time: 35.0012 s
epoch: 15 	step: 1500 / 9243 	train loss: 240366.4688 	time: 40.0057 s
epoch: 15 	step: 2000 / 9243 	train loss: 290412.0312 	time: 39.9959 s
epoch: 15 	step: 2500 / 9243 	train loss: 4110050.0000 	time: 40.0004 s
epoch: 15 	step: 3000 / 9243 	train loss: 1878318.1250 	time: 40.0006 s
epoch: 15 	step: 3500 / 9243 	train loss: 1786775.5000 	time: 40.0001 s
epoch: 15 	step: 4000 / 9243 	train loss: 349400.9375 	time: 38.1900 s
epoch: 15 	step: 4500 / 9243 	train loss: 212588.8438 	time: 39.9941 s
epoch: 15 	step: 5000 / 9243 	train loss: 4549679.5000 	time: 40.0058 s
epoch: 15 	step: 5500 / 9243 	train loss: 1063706.2500 	time: 39.9998 s
epoch: 15 	step: 6000 / 9243 	train loss: 463711.1875 	time: 39.9944 s
epoch: 15 	step: 6500 / 9243 	train loss: 1574546.5000 	time: 35.0019 s
epoch: 15 	step: 7000 / 9243 	train loss: 4521129.0000 	time: 40.0015 

  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 15 / 16 	train loss: 2965894.9570 	valid loss: 2923065.8091


  0%|          | 0/9243 [00:00<?, ?it/s]

epoch: 16 	step: 500 / 9243 	train loss: 815322.0000 	time: 34.9950 s
epoch: 16 	step: 1000 / 9243 	train loss: 784889.5000 	time: 34.9960 s
epoch: 16 	step: 1500 / 9243 	train loss: 4742601.5000 	time: 40.0019 s
epoch: 16 	step: 2000 / 9243 	train loss: 2257537.0000 	time: 40.0022 s
epoch: 16 	step: 2500 / 9243 	train loss: 1271860.8750 	time: 39.9948 s
epoch: 16 	step: 3000 / 9243 	train loss: 455303.1250 	time: 40.0006 s
epoch: 16 	step: 3500 / 9243 	train loss: 295170.1875 	time: 35.0312 s
epoch: 16 	step: 4000 / 9243 	train loss: 261102.3281 	time: 35.0004 s
epoch: 16 	step: 4500 / 9243 	train loss: 1600075.3750 	time: 35.0062 s
epoch: 16 	step: 5000 / 9243 	train loss: 242036.9375 	time: 35.0007 s
epoch: 16 	step: 5500 / 9243 	train loss: 789821.6875 	time: 35.0032 s
epoch: 16 	step: 6000 / 9243 	train loss: 3544526.0000 	time: 34.9686 s
epoch: 16 	step: 6500 / 9243 	train loss: 1694489.3750 	time: 39.9967 s
epoch: 16 	step: 7000 / 9243 	train loss: 2028059.1250 	time: 40.0039 s


  0%|          | 0/2306 [00:00<?, ?it/s]

  0%|          | 0/2306 [00:00<?, ?it/s]

Model saved to: model.pth
epoch: 16 / 16 	train loss: 2966185.4360 	valid loss: 2922059.2891


In [18]:
def predict(model, prompt):
    encoded_sent = tokenizer(prompt, padding='max_length', truncation=True, max_length=64, return_tensors="pt")
    encoded_sent = {k: v.to(device) for k, v in encoded_sent.items()}
    model.eval()
    with torch.no_grad():
        answer = model(ids=encoded_sent["input_ids"], mask=encoded_sent["attention_mask"])
    answer = answer.squeeze().item()
    
    return answer

In [22]:
predict(roberta, "Сколько будет, если к 35 прибавить 15?")

837.192626953125

In [15]:
class MLMLoss(nn.Module):
    # mask answer in sentence
    def forward(self, pred, true):
        mask = torch.equal(true, tokenizer.mask_token_id)
        mask = torch.FloatTensor(mask)
        loss = nn.functional.cross_entropy(pred, true, reduction="none")
        
        loss *= mask
        loss = torch.sum(loss) / torch.sum(mask)
        return loss