In [1]:
import math
import os

import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModelForMaskedLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling

In [2]:
class MaskedLanguageModelingDataset(Dataset):
    """Map-style dataset over pre-tokenized masked-language-modeling examples.

    ``encoded_data`` is any indexable sequence whose elements are mappings holding tensors
    under ``"input_ids"``, ``"attention_mask"`` and ``"labels"``. Each access returns a new
    dict with exactly those three keys, with every size-1 dimension squeezed out of each
    tensor. Presumably this strips a leading batch dimension of 1 left by the tokenizer
    (TODO confirm), so that the DataLoader's default collation can stack examples.
    """

    def __init__(self, encoded_data):
        # Attribute name is kept as-is: pickled instances restore it by name.
        self.encoded_data = encoded_data

    def __len__(self):
        return len(self.encoded_data)

    def __getitem__(self, idx):
        example = self.encoded_data[idx]
        return {field: example[field].squeeze() for field in ("input_ids", "attention_mask", "labels")}

In [3]:
# The splits are passed straight to DataLoader below, so they are presumably pickled Dataset
# objects (e.g. MaskedLanguageModelingDataset) rather than plain tensors — TODO confirm.
# Since torch 2.6, torch.load defaults to weights_only=True, which refuses arbitrary pickled
# objects, so opt out explicitly.
# SECURITY: weights_only=False unpickles arbitrary objects (can execute code) — only load
# .pt files you created yourself or otherwise trust.
train_dataset = torch.load('./train.pt', weights_only=False)
val_dataset = torch.load('./val.pt', weights_only=False)
test_dataset = torch.load('./test.pt', weights_only=False)

In [4]:
# Batch the training and validation splits.
BATCH_SIZE = 16
SEED = 42  # fixes the training shuffle order so Restart & Run All reproduces the same batches

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)  # No need to shuffle validation data

In [5]:
# Load the pretrained tokenizer and masked-LM model from one checkpoint name, so the two
# cannot silently drift apart. The "Some weights ... were not used" warning about the pooler
# and seq_relationship weights is expected here: the MLM head does not use them.
MODEL_NAME = "ai-forever/ruBert-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)

Some weights of the model checkpoint at ai-forever/ruBert-large were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU, and move the model there.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device)

In [7]:
# Set up AdamW over every model parameter.
# NOTE(review): 5e-7 is well below the learning rates usually used to fine-tune BERT-size
# models (roughly 1e-5 to 5e-5); confirm this value is intentional.
LEARNING_RATE = 5e-7
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [8]:
epochs: int = 30  # upper bound on training epochs; the early-stopping check in the training loop can end sooner

In [9]:
# Early-stopping state consumed (and updated) by the training loop.
patience = 2             # stop after this many consecutive epochs without a new best validation loss
stop_early = 0           # consecutive non-improving epochs counted so far
min_val_loss = math.inf  # best validation loss seen; +inf so any finite first loss counts as an improvement

In [10]:
# Train with per-epoch validation and early stopping.
# - Per-epoch train/validation losses are written to model_results.txt (overwritten on each run).
# - Every 5th epoch's weights and the best-so-far weights are saved under model/.
# Reads from earlier cells: model, optimizer, device, epochs, train_dataloader, val_dataloader,
# patience, stop_early, min_val_loss.
# NOTE: stop_early / min_val_loss are not reset here; re-run the early-stopping cell before
# re-running this one, otherwise training may stop immediately.
os.makedirs('model', exist_ok=True)  # torch.save does not create missing directories


def _unpack_batch(batch):
    """Return (input_ids, attention_mask, labels) from a collated batch, each moved to ``device``."""
    return (
        batch['input_ids'].to(device),
        batch['attention_mask'].to(device),
        batch['labels'].to(device),
    )


with open("model_results.txt", "w", encoding="utf8") as m_file:
    for epoch in range(epochs):
        # ---- Training pass ----
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_dataloader, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
        for batch in progress_bar:
            input_ids, attention_mask, labels = _unpack_batch(batch)

            # Clear gradients before backward, so stale gradients left by an interrupted
            # earlier run of this cell are never applied to the first step.
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()  # read the scalar once instead of calling .item() twice
            total_loss += batch_loss
            # Show the loss exactly as the model returns it (already averaged by the MLM head).
            # Dividing by len(batch) was wrong: that is the number of keys in the batch dict
            # (3), not the batch size.
            progress_bar.set_postfix({'training_loss': '{:.3f}'.format(batch_loss)})

        if epoch % 5 == 0:
            torch.save(model.state_dict(), f'model/checkpoint_{epoch}.pth')

        avg_train_loss = total_loss / len(train_dataloader)
        print(f"epoch {epoch}: Training Loss: {avg_train_loss:.2f}", file=m_file, flush=True)

        # ---- Validation pass (no gradient tracking) ----
        model.eval()
        total_val_loss = 0
        val_progress_bar = tqdm(val_dataloader, desc='Validation', leave=False, disable=False)
        with torch.no_grad():
            for batch in val_progress_bar:
                input_ids, attention_mask, labels = _unpack_batch(batch)
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                total_val_loss += outputs.loss.item()

        avg_val_loss = total_val_loss / len(val_dataloader)
        print(f"epoch {epoch}: Validation Loss: {avg_val_loss:.2f}", file=m_file, flush=True)

        # ---- Early-stopping bookkeeping ----
        if avg_val_loss < min_val_loss:
            min_val_loss = avg_val_loss
            stop_early = 0
            torch.save(model.state_dict(), 'model/newest_bert_best_model.pth')  # best weights so far
        else:
            stop_early += 1

        # Stop once validation loss has failed to improve for `patience` consecutive epochs.
        if stop_early >= patience:
            print("Early stopping triggered")
            print(f"Early stopping triggered at {epoch} epoch", file=m_file, flush=True)
            break

Epoch 0:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 5:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 6:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 7:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 8:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 9:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 10:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 11:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 12:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 13:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 14:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Epoch 15:   0%|          | 0/1080 [00:00<?, ?it/s]

Validation:   0%|          | 0/288 [00:00<?, ?it/s]

Early stopping triggered
