In [4]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


# Run on the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Root of the extracted aclImdb dataset; load_data reads <DATA_DIR>/<split>/{pos,neg}.
# The default is a machine-specific path, so allow overriding it through the
# IMDB_DATA_DIR environment variable instead of editing the source.
DATA_DIR = os.environ.get(
    'IMDB_DATA_DIR',
    '/home/jupyter/datasphere/project/MovieReview/Users/daniil/Desktop/University/AI/GreenAtom/aclImdb',
)


class IMDBDataset(Dataset):
    """Map-style dataset that tokenizes reviews on access.

    Each item is a dict with:
      * ``input_ids`` / ``attention_mask`` — the tokenizer output for one review,
        padded/truncated to ``max_length`` with the leading batch axis removed;
      * ``rating`` — the review's rating as a float tensor (regression target);
      * ``sentiment`` — the review's sentiment label as a long tensor
        (classification target).

    ``texts``, ``ratings`` and ``sentiments`` are parallel sequences indexed by
    the same position; ``tokenizer`` must provide a HuggingFace-style
    ``encode_plus``.
    """

    def __init__(self, texts, ratings, sentiments, tokenizer, max_length):
        self.texts, self.ratings, self.sentiments = texts, ratings, sentiments
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        review = str(self.texts[idx])
        score = self.ratings[idx]
        label = self.sentiments[idx]

        # return_tensors='pt' yields shape (1, max_length); squeeze(0) drops that
        # leading axis so the DataLoader can stack items into a batch itself.
        enc = self.tokenizer.encode_plus(
            review,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        item = {key: enc[key].squeeze(0) for key in ('input_ids', 'attention_mask')}
        item['rating'] = torch.tensor(score, dtype=torch.float)
        item['sentiment'] = torch.tensor(label, dtype=torch.long)
        return item


def load_data(data_dir, split, exclude_ratings=(5, 6)):
    """Read one split of the aclImdb dataset into three parallel lists.

    Reviews are read from ``<data_dir>/<split>/pos`` (sentiment 1) and then
    ``<data_dir>/<split>/neg`` (sentiment 0). Each review file is expected to be
    named ``<id>_<rating>.txt``; the rating is parsed from the file name. Files
    that do not end in ``.txt`` and macOS resource-fork files (``._*``) are
    ignored.

    Args:
        data_dir: Root directory of the extracted dataset.
        split: Sub-directory to read, e.g. ``'train'`` or ``'test'``.
        exclude_ratings: Ratings whose reviews are skipped. Defaults to 5 and 6.

    Returns:
        ``(texts, ratings, sentiments)``: lists of equal length. Within each
        label folder, files are visited in sorted name order, so the result is
        reproducible across runs and filesystems.

    Raises:
        ValueError: if a ``.txt`` file name does not end in ``_<integer>``.
    """
    excluded = set(exclude_ratings)
    texts, ratings, sentiments = [], [], []

    for label, sentiment in (('pos', 1), ('neg', 0)):
        dir_path = os.path.join(data_dir, split, label)
        # os.listdir returns entries in arbitrary order; sort for determinism.
        for filename in sorted(os.listdir(dir_path)):
            if not filename.endswith('.txt') or filename.startswith('._'):
                continue

            stem = filename[:-len('.txt')]
            _, sep, rating_str = stem.rpartition('_')
            if not sep:
                raise ValueError(f'cannot parse rating from file name {filename!r}')
            rating = int(rating_str)

            if rating in excluded:
                continue

            with open(os.path.join(dir_path, filename), 'r', encoding='utf-8') as f:
                text = f.read()
            texts.append(text)
            ratings.append(rating)
            sentiments.append(sentiment)

    return texts, ratings, sentiments


tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
MAX_LENGTH = 128  # tokens per review; longer reviews are truncated by the tokenizer
BATCH_SIZE = 16

train_texts, train_ratings, train_sentiments = load_data(DATA_DIR, 'train')
test_texts, test_ratings, test_sentiments = load_data(DATA_DIR, 'test')

train_dataset = IMDBDataset(train_texts, train_ratings, train_sentiments, tokenizer, MAX_LENGTH)
test_dataset = IMDBDataset(test_texts, test_ratings, test_sentiments, tokenizer, MAX_LENGTH)

# Page-locked host memory only helps when batches are copied to a CUDA device
# (the training/evaluation loops use .to(device, non_blocking=True)). On a
# CPU-only host, pin_memory=True just triggers a warning, so enable it only on CUDA.
_PIN_MEMORY = device.type == 'cuda'

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=_PIN_MEMORY
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, num_workers=4, pin_memory=_PIN_MEMORY
)


class SentimentRatingModel(nn.Module):
    """BERT encoder with two task heads sharing one pooled representation.

    * ``sentiment_classifier`` produces 2 logits (negative / positive).
    * ``rating_regressor`` produces a single scalar rating estimate.

    Both heads read the ``bert-base-uncased`` pooler output after a p=0.3 dropout.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(p=0.3)

        hidden = self.bert.config.hidden_size
        # Keep this construction order and layer layout: saved checkpoints depend
        # on the resulting state_dict keys (e.g. "sentiment_classifier.0.weight").
        self.sentiment_classifier = self._make_head(hidden, 2)
        self.rating_regressor = self._make_head(hidden, 1)

    @staticmethod
    def _make_head(in_features, out_features):
        """Build a Linear -> ReLU -> Dropout(0.1) -> Linear MLP with 256 hidden units."""
        return nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, out_features),
        )

    def forward(self, input_ids, attention_mask):
        """Return ``(sentiment_logits, rating_output)`` for a batch of token ids.

        ``sentiment_logits`` has a last dimension of size 2. ``rating_output`` is
        the regressor output with its trailing size-1 dimension squeezed away.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        features = self.dropout(encoded.pooler_output)
        return (
            self.sentiment_classifier(features),
            self.rating_regressor(features).squeeze(-1),
        )

def _train_one_epoch(model, loader, optimizer, scheduler, scaler,
                     criterion_sentiment, criterion_rating, device, use_amp, desc):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss.

    The objective is the unweighted sum of the sentiment loss and the rating loss.
    When ``use_amp`` is true, the forward pass runs under autocast and gradients
    are scaled by ``scaler``. The LR ``scheduler`` is stepped once per batch.
    """
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(loader, desc=desc)
    for batch in progress_bar:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(device, non_blocking=True)
        sentiments = batch['sentiment'].to(device, non_blocking=True)
        ratings = batch['rating'].to(device, non_blocking=True)

        # Mixed precision is enabled only on CUDA; on CPU this context is a no-op
        # instead of a CUDA-autocast that warns and disables itself.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            sentiment_logits, rating_output = model(input_ids, attention_mask)
            loss_sentiment = criterion_sentiment(sentiment_logits, sentiments)
            loss_rating = criterion_rating(rating_output, ratings)
            # NOTE(review): the rating MSE is computed on the raw rating scale and
            # dominates this sum early in training; consider normalising ratings
            # or weighting the two terms.
            loss = loss_sentiment + loss_rating

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        progress_bar.set_postfix({'Потеря': f'{batch_loss:.4f}'})

    # Guard against an empty loader instead of dividing by zero.
    return total_loss / max(len(loader), 1)


@torch.no_grad()
def _evaluate(model, loader, device):
    """Return ``(sentiment accuracy in percent, rating MAE)`` of ``model`` on ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct_sentiment = 0
    total_samples = 0
    # Running Python-float sum of |prediction - rating|. This avoids keeping one
    # numpy float32 per sample and summing them with float32 precision.
    abs_error_sum = 0.0

    for batch in tqdm(loader, desc='Оценка'):
        input_ids = batch['input_ids'].to(device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(device, non_blocking=True)
        sentiments = batch['sentiment'].to(device, non_blocking=True)
        ratings = batch['rating'].to(device, non_blocking=True)

        sentiment_logits, rating_output = model(input_ids, attention_mask)
        sentiment_preds = sentiment_logits.argmax(dim=1)

        correct_sentiment += (sentiment_preds == sentiments).sum().item()
        total_samples += sentiments.size(0)
        abs_error_sum += torch.abs(rating_output - ratings).sum().item()

    if total_samples == 0:
        raise ValueError('evaluation loader yielded no samples')

    return correct_sentiment / total_samples * 100, abs_error_sum / total_samples


if __name__ == "__main__":

    model = SentimentRatingModel().to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    criterion_sentiment = nn.CrossEntropyLoss()
    criterion_rating = nn.MSELoss()

    EPOCHS = 5
    total_steps = len(train_loader) * EPOCHS
    # Linear warm-up over the first 10% of steps, then linear decay to zero.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps
    )

    # AMP and loss scaling are CUDA features; turn them off explicitly on CPU
    # rather than relying on them to warn and disable themselves.
    use_amp = device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for epoch in range(EPOCHS):
        avg_loss = _train_one_epoch(
            model, train_loader, optimizer, scheduler, scaler,
            criterion_sentiment, criterion_rating, device, use_amp,
            desc=f'Эпоха {epoch + 1}/{EPOCHS}',
        )
        print(f'Средняя потеря за эпоху: {avg_loss:.4f}')

        # Checkpoint after every epoch so an interrupted run can be resumed/inspected.
        torch.save(model.state_dict(), f'sentiment_rating_model_epoch_FORM1_{epoch + 1}.pth')
        print(f'Модель сохранена после эпохи {epoch + 1}')

    sentiment_accuracy, rating_mae = _evaluate(model, test_loader, device)

    print(f'Точность тональности: {sentiment_accuracy:.2f}%')
    print(f'Средняя абсолютная ошибка рейтинга: {rating_mae:.2f}')

    torch.save(model.state_dict(), 'sentiment_rating_model_final.pth')
    print("Финальная модель успешно сохранена.")


Эпоха 1/5: 100%|██████████| 1563/1563 [02:05<00:00, 12.49it/s, Потеря=1.2772] 


Средняя потеря за эпоху: 11.0779
Модель сохранена после эпохи 1


Эпоха 2/5: 100%|██████████| 1563/1563 [02:07<00:00, 12.28it/s, Потеря=9.3936] 


Средняя потеря за эпоху: 3.4132
Модель сохранена после эпохи 2


Эпоха 3/5: 100%|██████████| 1563/1563 [02:07<00:00, 12.26it/s, Потеря=1.3217] 


Средняя потеря за эпоху: 2.0853
Модель сохранена после эпохи 3


Эпоха 4/5: 100%|██████████| 1563/1563 [02:06<00:00, 12.35it/s, Потеря=1.1323]


Средняя потеря за эпоху: 1.4769
Модель сохранена после эпохи 4


Эпоха 5/5: 100%|██████████| 1563/1563 [02:03<00:00, 12.64it/s, Потеря=0.4016]


Средняя потеря за эпоху: 1.1727
Модель сохранена после эпохи 5


Оценка: 100%|██████████| 1563/1563 [01:06<00:00, 23.48it/s]


Точность тональности: 89.03%
Средняя абсолютная ошибка рейтинга: 1.28
Финальная модель успешно сохранена.
