In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.metrics import mean_squared_error, accuracy_score
import json
from tqdm import tqdm
import numpy as np
import os

  from .autonotebook import tqdm as notebook_tqdm
2024-11-01 20:27:00.431950: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-01 20:27:00.490455: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-01 20:27:00.490495: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-01 20:27:00.491912: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-01 20:27:00.5

In [2]:
# Load data
with open('data/poetry_data_train.json', 'r', encoding='utf-8') as file:
    data = json.load(file)


data[0]

{'url': 'https://www.chitalnya.ru/work/3180020/',
 'rating': '29',
 'views': 33,
 'output_text': 'Люблю ли осень? Ты спроси у ели -\nОна в колючках и не ждёт тепла...\nЕй хрен один: повсюду лишь метели,\nКапель ли хнычет ночью со стекла.\n\nБагряный лист, ли, кружится прощаясь,\nЛи фонаря торчащего кадык,\nНа низком небе пятнами касаясь\nК луне общаясь, притулился встык.\n\nЛюблю ли осень? Ты спроси у ветра.\nС ним всё равно не сладить в парусах.\nПо осени ли, по весне для шторма,\nСпускают их, на реи намотав.\n\nПожухлых трав не кошенные дали,\nЧьих жалких тел охапками нажнут.\nСпроси у них, о том они мечтали\nВесной, тогда!? Теперь их просто жгут.\n\nПо осени, когда седеют чувства,\nЗастыли ветви рук на деревах -\nМне хочется упасть на дно колодца,\nЧто бы не видеть слёзы на глазах.',
 'genre': 'лирика'}

In [3]:
# Dataset class
class PoemRatingDataset(Dataset):
    def __init__(self, poems, ratings, tokenizer):
        self.poems = poems
        self.ratings = ratings
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.poems)

    def __getitem__(self, idx):
        poem = self.poems[idx]
        rating = self.ratings[idx]
        encoding = self.tokenizer(
            poem,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors='pt'
        )
        item = {key: val.squeeze(0) for key, val in encoding.items()}
        item['labels'] = torch.tensor(rating, dtype=torch.float)
        return item

In [26]:
poems = [entry['output_text'] for entry in data if float(entry['rating']) > 0 and entry['views'] >= 50]
ratings = [float(entry['rating']) for entry in data if float(entry['rating']) > 0 and entry['views'] >= 50]

# Load tokenizer
tokenizer = BertTokenizer.from_pretrained('DeepPavlov/rubert-base-cased')
dataset = PoemRatingDataset(poems, ratings, tokenizer)

# Train-validation split
train_size = int(0.95 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16)

In [5]:
from transformers import BertModel

# Model definition
class BertForRegression(torch.nn.Module):
    def __init__(self):
        super(BertForRegression, self).__init__()
        self.bert = BertModel.from_pretrained('DeepPavlov/rubert-base-cased')
        self.regressor = torch.nn.Linear(self.bert.config.hidden_size, 1)
        for param in self.bert.parameters():
            param.requires_grad = False

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        pooled_output = outputs.pooler_output
        rating = self.regressor(pooled_output)
        if labels is not None:
            loss_fn = torch.nn.MSELoss()
            loss = loss_fn(rating.view(-1), labels.view(-1))
            return loss, rating
        return rating

In [27]:
# Initialize model, optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertForRegression().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-8)

Some weights of the model checkpoint at DeepPavlov/rubert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [28]:
# Training with validation metrics
num_epochs = 32
model.train()
for epoch in range(num_epochs):
    total_train_loss = 0
    model.train()
    for batch in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch'):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch.get('token_type_ids', None).to(device)
        labels = batch['labels'].to(device)

        loss, _ = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels
        )
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0
    val_predictions = []
    val_labels = []
    with torch.no_grad():
        for batch in val_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch.get('token_type_ids', None).to(device)
            labels = batch['labels'].to(device)
    
            loss, predictions = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                        labels=labels
                    )

            val_loss += loss.item()
            val_predictions.extend(predictions.cpu().numpy().flatten())
            val_labels.extend(labels.cpu().numpy().flatten())

        val_mse = mean_squared_error(val_labels, val_predictions)
        print(f'Epoch {epoch+1} completed. Train Loss: {total_train_loss / len(train_dataloader):.4f}, Validation MSE: {val_mse:.4f}')

    # Save the model and tokenizer
    os.mkdir(f'models_checkpoints/{epoch}')
    torch.save(model.state_dict(), f'models_checkpoints/{epoch}/regression_poem_rating_model_filtered.pth')
    tokenizer.save_pretrained(f'models_checkpoints/{epoch}/regression_poem_rating_model_filtered')

Epoch 1/32: 100%|██████████| 2237/2237 [04:38<00:00,  8.03batch/s]


Epoch 1 completed. Train Loss: 3669.2557, Validation MSE: 3162.4363


Epoch 2/32: 100%|██████████| 2237/2237 [04:33<00:00,  8.18batch/s]


Epoch 2 completed. Train Loss: 3165.4602, Validation MSE: 2861.1584


Epoch 3/32: 100%|██████████| 2237/2237 [04:33<00:00,  8.19batch/s]


Epoch 3 completed. Train Loss: 2984.2838, Validation MSE: 2776.5127


Epoch 4/32: 100%|██████████| 2237/2237 [04:32<00:00,  8.20batch/s]


Epoch 4 completed. Train Loss: 2937.1735, Validation MSE: 2748.0017


Epoch 5/32: 100%|██████████| 2237/2237 [04:32<00:00,  8.20batch/s]


Epoch 5 completed. Train Loss: 2931.9633, Validation MSE: 2732.2371


Epoch 6/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.24batch/s]


Epoch 6 completed. Train Loss: 2908.0497, Validation MSE: 2720.1118


Epoch 7/32: 100%|██████████| 2237/2237 [04:32<00:00,  8.20batch/s]


Epoch 7 completed. Train Loss: 2894.6096, Validation MSE: 2709.9883


Epoch 8/32: 100%|██████████| 2237/2237 [04:32<00:00,  8.20batch/s]


Epoch 8 completed. Train Loss: 2893.1187, Validation MSE: 2701.1777


Epoch 9/32: 100%|██████████| 2237/2237 [04:32<00:00,  8.22batch/s]


Epoch 9 completed. Train Loss: 2880.5236, Validation MSE: 2693.6377


Epoch 10/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.23batch/s]


Epoch 10 completed. Train Loss: 2879.8679, Validation MSE: 2687.1343


Epoch 11/32: 100%|██████████| 2237/2237 [04:32<00:00,  8.21batch/s]


Epoch 11 completed. Train Loss: 2868.4032, Validation MSE: 2681.2354


Epoch 12/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.23batch/s]


Epoch 12 completed. Train Loss: 2859.3086, Validation MSE: 2676.4548


Epoch 13/32: 100%|██████████| 2237/2237 [04:32<00:00,  8.22batch/s]


Epoch 13 completed. Train Loss: 2856.3177, Validation MSE: 2672.3147


Epoch 14/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.25batch/s]


Epoch 14 completed. Train Loss: 2850.3777, Validation MSE: 2668.9626


Epoch 15/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.23batch/s]


Epoch 15 completed. Train Loss: 2850.9102, Validation MSE: 2665.6926


Epoch 16/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.23batch/s]


Epoch 16 completed. Train Loss: 2847.5975, Validation MSE: 2662.7878


Epoch 17/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.23batch/s]


Epoch 17 completed. Train Loss: 2846.9402, Validation MSE: 2660.1528


Epoch 18/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.24batch/s]


Epoch 18 completed. Train Loss: 2849.7261, Validation MSE: 2657.9385


Epoch 19/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.24batch/s]


Epoch 19 completed. Train Loss: 2843.7281, Validation MSE: 2655.6018


Epoch 20/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.24batch/s]


Epoch 20 completed. Train Loss: 2839.0633, Validation MSE: 2653.9353


Epoch 21/32: 100%|██████████| 2237/2237 [04:30<00:00,  8.26batch/s]


Epoch 21 completed. Train Loss: 2842.6346, Validation MSE: 2652.2273


Epoch 22/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.23batch/s]


Epoch 22 completed. Train Loss: 2839.5618, Validation MSE: 2650.7461


Epoch 23/32: 100%|██████████| 2237/2237 [04:30<00:00,  8.27batch/s]


Epoch 23 completed. Train Loss: 2839.7816, Validation MSE: 2649.1301


Epoch 24/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.23batch/s]


Epoch 24 completed. Train Loss: 2836.2177, Validation MSE: 2647.8801


Epoch 25/32: 100%|██████████| 2237/2237 [04:30<00:00,  8.26batch/s]


Epoch 25 completed. Train Loss: 2834.0895, Validation MSE: 2646.8328


Epoch 26/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.23batch/s]


Epoch 26 completed. Train Loss: 2833.2021, Validation MSE: 2645.7217


Epoch 27/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.23batch/s]


Epoch 27 completed. Train Loss: 2830.6131, Validation MSE: 2644.9097


Epoch 28/32: 100%|██████████| 2237/2237 [04:30<00:00,  8.26batch/s]


Epoch 28 completed. Train Loss: 2829.6218, Validation MSE: 2644.1179


Epoch 29/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.23batch/s]


Epoch 29 completed. Train Loss: 2828.0773, Validation MSE: 2643.3384


Epoch 30/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.24batch/s]


Epoch 30 completed. Train Loss: 2831.0470, Validation MSE: 2642.1218


Epoch 31/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.24batch/s]


Epoch 31 completed. Train Loss: 2830.2438, Validation MSE: 2641.7744


Epoch 32/32: 100%|██████████| 2237/2237 [04:31<00:00,  8.25batch/s]


Epoch 32 completed. Train Loss: 2829.2167, Validation MSE: 2641.1838


### Расчет NDCG на отложенной выборке

In [6]:
# Load data
with open('data/poetry_data_test.json', 'r', encoding='utf-8') as file:
    test_data = json.load(file)


test_data[0]

{'url': 'https://www.chitalnya.ru/work/2999707/',
 'rating': '0',
 'views': 15,
 'output_text': 'Нынче головы четыре у дракона,\nПожиратель президентов он мастак.\nНа него управы нет, и нет закона,\nВ споре с ним же попадёт любой впросак.\n\nУ него в кармане сотни триллионов.\nДля него богач - обычный нищеброд.\nНету больше никаких на свете тронов,\nНи во что не ставит он любой народ.\n\nОн прикажет слугам - выпилить любого\nС интернета, те его прогонят прочь,\nИ никто не пикнет, не промолвит слова,\nНекому бедняге на земле помочь.\n\nИ всё потому, что присягнули змею,\nОтступив от Бога, скорбен наш удел.\nСлужат люди молча мерзкому пигмею,\nПоголовно, будто вовсе оборзев.\n\nНынче люди лживы, врут напраполую.\nОн сегодня баба, завтра же мужик.\nГрабят, убивают, даже мать родную\nМогут погубить, коль сам дракон велит.\n\nВремена драконьи - времена инферно.\nКовид пандемия - следствие его,\nОбезумевши, спешит планету ввергнуть\nВ бездну - так уже желает бошинство.',
 'genre': 'лирика'}

In [8]:
path = '/root/work/bert_regression/regression_poem_rating_model_filtered.pth'
model = BertForRegression()

model.load_state_dict(torch.load(path, weights_only=True))

Some weights of the model checkpoint at DeepPavlov/rubert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [10]:
poems = [entry['output_text'] for entry in test_data]
ratings = [float(entry['rating']) for entry in test_data]

tokenizer = BertTokenizer.from_pretrained('DeepPavlov/rubert-base-cased')
test_dataset = PoemRatingDataset(poems, ratings, tokenizer)
test_dataloader = DataLoader(test_dataset, batch_size=32)

In [12]:
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import ndcg_score
import numpy as np
from tqdm import tqdm 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  

all_predictions = []
all_true_ratings = []

with torch.no_grad():
    for batch in tqdm(test_dataloader):
        inputs = batch['input_ids']
        attention_mask = batch['attention_mask'] 
        
        inputs = inputs.to(device)
        attention_mask = attention_mask.to(device)

        outputs = model(inputs, attention_mask=attention_mask)
        predictions = outputs.cpu().numpy()

        all_predictions.extend(predictions.flatten())
        all_true_ratings.extend(batch['labels'].cpu().numpy())

y_true = np.array(all_true_ratings)
y_scores = np.array(all_predictions)

y_true_reshaped = y_true.reshape(1, -1)  # Reshape to match the expected input shape
y_scores_reshaped = y_scores.reshape(1, -1)

ndcg = ndcg_score(y_true_reshaped, y_scores_reshaped)

print(f"NDCG Score: {ndcg}")


  0%|          | 0/1082 [00:00<?, ?it/s]

100%|██████████| 1082/1082 [04:51<00:00,  3.71it/s]

NDCG Score: 0.6982737505125434





### Pairwise Approach

In [69]:
class PairwisePoemRankingDataset(Dataset):
    def __init__(self, poems, ratings, tokenizer):
        self.poems = poems
        self.ratings = ratings
        self.tokenizer = tokenizer
        self.pairs = self.create_pairs()

    def create_pairs(self):
        # Create pairs of (poem1, poem2, label), where label is 1 if poem1 > poem2, otherwise 0
        pairs = []
        for i in range(len(self.poems)):
            for j in range(i + 1, len(self.poems)):
                if self.ratings[i] != self.ratings[j]:
                    higher = i if self.ratings[i] > self.ratings[j] else j
                    lower = j if higher == i else i
                    pairs.append((self.poems[higher], self.poems[lower], 1))  # higher should rank above lower
                    pairs.append((self.poems[lower], self.poems[higher], 0))  # lower should rank below higher
        return pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        poem1, poem2, label = self.pairs[idx]
        encoding1 = self.tokenizer(
            poem1,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors='pt'
        )
        encoding2 = self.tokenizer(
            poem2,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors='pt'
        )

        item = {
            'input_ids1': encoding1['input_ids'].squeeze(0),
            'attention_mask1': encoding1['attention_mask'].squeeze(0),
            'input_ids2': encoding2['input_ids'].squeeze(0),
            'attention_mask2': encoding2['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.float)
        }
        return item


dataset = PairwisePoemRankingDataset(poems, ratings, tokenizer)

# Train-validation split
train_size = int(0.95 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16)

In [71]:
# Pairwise ranking model
class BertForPairwiseRanking(torch.nn.Module):
    def __init__(self):
        super(BertForPairwiseRanking, self).__init__()
        self.bert = BertModel.from_pretrained('DeepPavlov/rubert-base-cased')
        self.regressor = torch.nn.Linear(self.bert.config.hidden_size, 1)
        for param in self.bert.parameters():
            param.requires_grad = False

    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2, labels=None):
        # Forward pass for the first input
        outputs1 = self.bert(input_ids=input_ids1, attention_mask=attention_mask1)
        pooled_output1 = outputs1.pooler_output
        score1 = self.regressor(pooled_output1)

        # Forward pass for the second input
        outputs2 = self.bert(input_ids=input_ids2, attention_mask=attention_mask2)
        pooled_output2 = outputs2.pooler_output
        score2 = self.regressor(pooled_output2)

        # Pairwise ranking loss
        if labels is not None:
            # Ranking loss: Margin-based hinge loss
            margin = 1.0
            ranking_loss = torch.nn.MarginRankingLoss(margin=margin)
            target = 2 * labels - 1  # Convert 0/1 label to -1/1 for ranking loss
            loss = ranking_loss(score1.view(-1), score2.view(-1), target)
            return loss, score1, score2
        return score1, score2

# Initialize model, optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertForPairwiseRanking().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-8)

Some weights of the model checkpoint at DeepPavlov/rubert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [72]:
# Function to calculate Accuracy@K
def accuracy_at_k(labels, predictions, k=5):
    top_k_indices = np.argsort(predictions)[-k:]
    top_k_labels = labels[top_k_indices]
    return np.mean(top_k_labels)

# Training with validation
num_epochs = 3
model.train()
for epoch in range(num_epochs):
    total_train_loss = 0
    model.train()
    for batch in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch'):
        optimizer.zero_grad()
        input_ids1 = batch['input_ids1'].to(device)
        attention_mask1 = batch['attention_mask1'].to(device)
        input_ids2 = batch['input_ids2'].to(device)
        attention_mask2 = batch['attention_mask2'].to(device)
        labels = batch['labels'].to(device)

        loss, _, _ = model(
            input_ids1=input_ids1,
            attention_mask1=attention_mask1,
            input_ids2=input_ids2,
            attention_mask2=attention_mask2,
            labels=labels
        )
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0
    val_predictions = []
    val_labels = []
    with torch.no_grad():
        for batch in val_dataloader:
            input_ids1 = batch['input_ids1'].to(device)
            attention_mask1 = batch['attention_mask1'].to(device)
            input_ids2 = batch['input_ids2'].to(device)
            attention_mask2 = batch['attention_mask2'].to(device)
            labels = batch['labels'].to(device)

            loss, score1, score2 = model(
                input_ids1=input_ids1,
                attention_mask1=attention_mask1,
                input_ids2=input_ids2,
                attention_mask2=attention_mask2,
                labels=labels
            )
            val_loss += loss.item()
            val_predictions.extend(score1.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_accuracy = accuracy_score(val_labels, (np.array(val_predictions) > 0.5).astype(int))
    acc_at_k = accuracy_at_k(np.array(val_labels), np.array(val_predictions), k=5)
    
    print(f'Epoch {epoch+1} completed. Train Loss: {total_train_loss / len(train_dataloader):.4f}, Validation Accuracy: {val_accuracy:.4f}, Accuracy@5: {acc_at_k:.4f}')

    torch.save(model.state_dict(), f'models_checkpoints/{epoch}/pairwise_poem_rating_model_filtered.pth')
    tokenizer.save_pretrained(f'models_checkpoints/{epoch}/pairwise_poem_rating_model_filtered')


Epoch 1/3:   0%|          | 0/48035348 [01:22<?, ?batch/s]


MemoryError: 