In [2]:
import torch
import os
import re
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from collections import Counter
import nltk

# nltk.download('punkt')
# nltk.download('stopwords')
def tokenize(text):
    text = re.sub(r'<.*?>', '', text)  # 去除HTML标签
    text = re.sub(r'[^a-zA-Z0-9]', ' ', text)  # 只保留字母和数字
    text = text.lower().split()
    return text

class MyDataset(Dataset):
    def __init__(self, train=True, max_len=500):
        self.train_data_path= r"./aclImdb/train"
        self.test_data_path  = r"./aclImdb/test"
        data_path = self.train_data_path if train else self.test_data_path
        # 所有文件名放入列表
        temp_data_path = [os.path.join(data_path, "pos"), os.path.join(data_path, "neg")]
        self.total_file_path = []  # 所有评论文件路径
        self.labels = []  # 标签列表
        self.max_len = max_len  # 最大序列长度
        for path in temp_data_path:
            file_name_list = os.listdir(path)
            file_path_list = [os.path.join(path, i) for i in file_name_list if i.endswith(".txt")]
            self.total_file_path.extend(file_path_list)
            self.labels.extend([1 if "pos" in path else 0] * len(file_name_list))
        
        # 构建词汇表
        self.vocab = self.build_vocab()
        
    def build_vocab(self, max_vocab_size=20000):
        counter = Counter()
        for file_path in self.total_file_path:
            with open(file_path, 'r', encoding='utf-8') as f:
                text = f.read()
                tokens = tokenize(text)
                counter.update(tokens)
        vocab = {word: idx + 2 for idx, (word, _) in enumerate(counter.most_common(max_vocab_size))}
        vocab['<PAD>'] = 0
        vocab['<UNK>'] = 1
        return vocab
    
    def text_to_sequence(self, text):
        tokens = tokenize(text)
        sequence = [self.vocab.get(word, self.vocab['<UNK>']) for word in tokens]
        if len(sequence) < self.max_len:
            sequence += [self.vocab['<PAD>']] * (self.max_len - len(sequence))
        else:
            sequence = sequence[:self.max_len]
        return sequence
    
    def __getitem__(self, index):
        file_path = self.total_file_path[index]
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        sequence = self.text_to_sequence(text)
        label = self.labels[index]
        return torch.tensor(sequence, dtype=torch.long), torch.tensor(label, dtype=torch.float)  # 修改标签为浮点数类型
        
    def __len__(self):
        return len(self.total_file_path)

def get_dataloader(train=True, batch_size=32):
    dataset = MyDataset(train)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, bidirectional=bidirectional, dropout=dropout, batch_first=True)
        
        self.fc1 = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, text):
        embedded = self.dropout(self.embedding(text))
        lstm_out, (hidden, cell) = self.lstm(embedded)
        
        if self.lstm.bidirectional:
            hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))
        else:
            hidden = self.dropout(hidden[-1,:,:])
        
        output = self.fc1(hidden)
        
        return output


def train(model, iterator, optimizer, criterion, device):
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    
    for batch in iterator:
        optimizer.zero_grad()
        text, labels = batch
        text, labels = text.to(device), labels.to(device)
        predictions = model(text).squeeze(1)
        loss = criterion(predictions, labels)
        acc = binary_accuracy(predictions, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        print(epoch + 1, loss.item())
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def evaluate(model, iterator, criterion, device):
    model.eval()
    epoch_loss = 0
    epoch_acc = 0
    
    with torch.no_grad():
        for batch in iterator:
            text, labels = batch
            text, labels = text.to(device), labels.to(device)
            predictions = model(text).squeeze(1)
            loss = criterion(predictions, labels)
            acc = binary_accuracy(predictions, labels)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
            
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def binary_accuracy(preds, y):
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()
    acc = correct.sum() / len(correct)
    return acc

# 检查是否有可用的GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 超参数
INPUT_DIM = len(MyDataset().vocab)
EMBEDDING_DIM = 200
HIDDEN_DIM = 128
OUTPUT_DIM = 1
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.6
BATCH_SIZE = 32
N_EPOCHS = 5

# 模型实例化
model = LSTMModel(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, BIDIRECTIONAL, DROPOUT).to(device)

# 优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=3e-5)
criterion = nn.BCEWithLogitsLoss().to(device)

# 数据加载器
train_loader = get_dataloader(train=True, batch_size=BATCH_SIZE)
test_loader = get_dataloader(train=False, batch_size=BATCH_SIZE)

# 训练模型
for epoch in range(N_EPOCHS):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    input()
    print(f'\tTest Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
    
    print(f'Epoch: {epoch+1}')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')

1 0.7078604698181152
1 0.687791109085083
1 0.6932042837142944
1 0.6783578395843506
1 0.6944519281387329
1 0.676552414894104
1 0.7204541563987732
1 0.7437672019004822
1 0.7237001061439514
1 0.6574144959449768
1 0.7037581205368042
1 0.6962838172912598
1 0.6889625787734985
1 0.6899944543838501
1 0.6919846534729004
1 0.69655442237854
1 0.6990199685096741
1 0.6984058618545532
1 0.6851224899291992
1 0.6837651133537292
1 0.6762213706970215
1 0.6820398569107056
1 0.7307394742965698
1 0.7142940163612366
1 0.696173906326294
1 0.6720126867294312
1 0.69331955909729
1 0.7029737234115601
1 0.7223654985427856
1 0.7060472965240479
1 0.6953036785125732
1 0.681267261505127
1 0.7121309041976929
1 0.6853231191635132
1 0.7030692100524902
1 0.6884393692016602
1 0.691295325756073
1 0.7129917144775391
1 0.6890244483947754
1 0.6989738941192627
1 0.6861132383346558
1 0.708427906036377
1 0.6850594878196716
1 0.6902672052383423
1 0.6895478367805481
1 0.7137308716773987
1 0.7085322737693787
1 0.7182594537734985
1 